kaiko-eva 0.0.0.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (111)
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +212 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/vision/data/datasets/_validators.py
@@ -0,0 +1,44 @@
+"""Dataset validation related functions."""
+
+from typing_extensions import List, Tuple
+
+from eva.vision.data.datasets import vision
+
+_SUFFIX_ERROR_MESSAGE = "Please verify that the data are properly downloaded and stored."
+"""Common suffix dataset verification error message."""
+
+
+def check_dataset_integrity(
+    dataset: vision.VisionDataset,
+    *,
+    length: int,
+    n_classes: int,
+    first_and_last_labels: Tuple[str, str],
+) -> None:
+    """Verifies the dataset's integrity.
+
+    Raises:
+        ValueError: If the input dataset's values do not
+            match the expected ones.
+    """
+    if len(dataset) != length:
+        raise ValueError(
+            f"Dataset's '{dataset.__class__.__qualname__}' length "
+            f"({len(dataset)}) does not match the expected one ({length}). "
+            f"{_SUFFIX_ERROR_MESSAGE}"
+        )
+
+    dataset_classes: List[str] = getattr(dataset, "classes", [])
+    if dataset_classes and len(dataset_classes) != n_classes:
+        raise ValueError(
+            f"Dataset's '{dataset.__class__.__qualname__}' number of classes "
+            f"({len(dataset_classes)}) does not match the expected one ({n_classes}). "
+            f"{_SUFFIX_ERROR_MESSAGE}"
+        )
+
+    if dataset_classes and (dataset_classes[0], dataset_classes[-1]) != first_and_last_labels:
+        raise ValueError(
+            f"Dataset's '{dataset.__class__.__qualname__}' first and last labels "
+            f"({(dataset_classes[0], dataset_classes[-1])}) do not match the expected "
+            f"ones ({first_and_last_labels}). {_SUFFIX_ERROR_MESSAGE}"
+        )
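The checks above rely only on `len(dataset)`, an optional `classes` attribute, and the class name, so they are easy to exercise in isolation. A minimal sketch, using a plain duck-typed stand-in instead of a real `VisionDataset` subclass (the `_FakeDataset` name is invented for illustration):

from eva.vision.data.datasets import _validators


class _FakeDataset:
    # Only the attributes the validator actually inspects.
    classes = ["ADI", "BACK", "TUM"]

    def __len__(self) -> int:
        return 100


# Passes silently: every expectation matches.
_validators.check_dataset_integrity(
    _FakeDataset(),  # type: ignore[arg-type]
    length=100,
    n_classes=3,
    first_and_last_labels=("ADI", "TUM"),
)

# Raises ValueError: length (100) does not match the expected one (99).
_validators.check_dataset_integrity(
    _FakeDataset(),  # type: ignore[arg-type]
    length=99,
    n_classes=3,
    first_and_last_labels=("ADI", "TUM"),
)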
eva/vision/data/datasets/classification/__init__.py
@@ -0,0 +1,15 @@
+"""Image classification datasets API."""
+
+from eva.vision.data.datasets.classification.bach import BACH
+from eva.vision.data.datasets.classification.crc import CRC
+from eva.vision.data.datasets.classification.mhist import MHIST
+from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
+from eva.vision.data.datasets.classification.total_segmentator import TotalSegmentatorClassification
+
+__all__ = [
+    "BACH",
+    "CRC",
+    "MHIST",
+    "PatchCamelyon",
+    "TotalSegmentatorClassification",
+]
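The module is plain re-exports, so downstream code can import every concrete dataset from the subpackage level rather than from its defining module. A trivial sketch, assuming the wheel is installed:

# Both import paths resolve to the same class object.
from eva.vision.data.datasets.classification import BACH
from eva.vision.data.datasets.classification.bach import BACH as BachDirect

assert BACH is BachDirect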
eva/vision/data/datasets/classification/bach.py
@@ -0,0 +1,174 @@
+"""BACH dataset class."""
+
+import os
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from torchvision.datasets import folder, utils
+from typing_extensions import override
+
+from eva.vision.data.datasets import _utils, _validators, structs
+from eva.vision.data.datasets.classification import base
+from eva.vision.utils import io
+
+
+class BACH(base.ImageClassification):
+    """Dataset class for BACH images and corresponding targets."""
+
+    _train_index_ranges: List[Tuple[int, int]] = [
+        (0, 41),
+        (59, 60),
+        (90, 139),
+        (169, 240),
+        (258, 260),
+        (273, 345),
+        (368, 400),
+    ]
+    """Train range indices."""
+
+    _val_index_ranges: List[Tuple[int, int]] = [
+        (41, 59),
+        (60, 90),
+        (139, 169),
+        (240, 258),
+        (260, 273),
+        (345, 368),
+    ]
+    """Validation range indices."""
+
+    _resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="ICIAR2018_BACH_Challenge.zip",
+            url="https://zenodo.org/records/3632035/files/ICIAR2018_BACH_Challenge.zip",
+        ),
+    ]
+    """Dataset resources."""
+
+    _license: str = "CC BY-NC-ND 4.0 (https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode)"
+    """Dataset license."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val"] | None = None,
+        download: bool = False,
+        image_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        The dataset is split into train and validation by taking into account
+        the patient IDs to avoid any data leakage.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            download: Whether to download the data for the specified split.
+                Note that the download will be executed only by additionally
+                calling the :meth:`prepare_data` method and if the data does
+                not yet exist on disk.
+            image_transforms: A function/transform that takes in an image
+                and returns a transformed version.
+            target_transforms: A function/transform that takes in the target
+                and transforms it.
+        """
+        super().__init__(
+            image_transforms=image_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._root = root
+        self._split = split
+        self._download = download
+
+        self._samples: List[Tuple[str, int]] = []
+        self._indices: List[int] = []
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["Benign", "InSitu", "Invasive", "Normal"]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"Benign": 0, "InSitu": 1, "Invasive": 2, "Normal": 3}
+
+    @property
+    def dataset_path(self) -> str:
+        """Returns the path of the image data of the dataset."""
+        return os.path.join(self._root, "ICIAR2018_BACH_Challenge", "Photos")
+
+    @override
+    def filename(self, index: int) -> str:
+        image_path, _ = self._samples[self._indices[index]]
+        return os.path.relpath(image_path, self.dataset_path)
+
+    @override
+    def prepare_data(self) -> None:
+        if self._download:
+            self._download_dataset()
+
+    @override
+    def configure(self) -> None:
+        self._samples = folder.make_dataset(
+            directory=self.dataset_path,
+            class_to_idx=self.class_to_idx,
+            extensions=(".tif",),
+        )
+        self._indices = self._make_indices()
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length={"train": 268, "val": 132, None: 400}[self._split],
+            n_classes=4,
+            first_and_last_labels=("Benign", "Normal"),
+        )
+
+    @override
+    def load_image(self, index: int) -> np.ndarray:
+        image_path, _ = self._samples[self._indices[index]]
+        return io.read_image(image_path)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        _, target = self._samples[self._indices[index]]
+        return np.asarray(target, dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._indices)
+
+    def _download_dataset(self) -> None:
+        """Downloads the dataset."""
+        for resource in self._resources:
+            if os.path.isdir(self.dataset_path):
+                continue
+
+            self._print_license()
+            utils.download_and_extract_archive(
+                resource.url,
+                download_root=self._root,
+                filename=resource.filename,
+                remove_finished=True,
+            )
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
+
+    def _make_indices(self) -> List[int]:
+        """Builds the dataset indices for the specified split."""
+        split_index_ranges = {
+            "train": self._train_index_ranges,
+            "val": self._val_index_ranges,
+            None: [(0, 400)],
+        }
+        index_ranges = split_index_ranges.get(self._split)
+        if index_ranges is None:
+            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
+
+        return _utils.ranges_to_indices(index_ranges)
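End to end, the lifecycle is: `prepare_data` downloads and extracts the archive (only when `download=True`), `configure` scans the extracted folder into `(path, target)` samples plus the split indices, and `validate` runs the integrity checks before any indexing. A usage sketch, assuming `/data/bach` is a writable local path:

from eva.vision.data.datasets.classification import BACH

dataset = BACH(root="/data/bach", split="train", download=True)
dataset.prepare_data()  # fetches ICIAR2018_BACH_Challenge.zip from Zenodo if missing
dataset.configure()     # builds the samples list and the train-split indices
dataset.validate()      # expects 268 train samples over 4 classes

image, target = dataset[0]  # numpy image and int64 target
print(len(dataset), dataset.classes[int(target)])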
eva/vision/data/datasets/classification/base.py
@@ -0,0 +1,103 @@
+"""Base for image classification datasets."""
+
+import abc
+from typing import Any, Callable, Dict, List, Tuple
+
+import numpy as np
+from typing_extensions import override
+
+from eva.vision.data.datasets import vision
+
+
+class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+    """Image classification abstract dataset."""
+
+    def __init__(
+        self,
+        image_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the image classification dataset.
+
+        Args:
+            image_transforms: A function/transform that takes in an image
+                and returns a transformed version.
+            target_transforms: A function/transform that takes in the target
+                and transforms it.
+        """
+        super().__init__()
+
+        self._image_transforms = image_transforms
+        self._target_transforms = target_transforms
+
+    @property
+    def classes(self) -> List[str] | None:
+        """Returns the list of the dataset class names."""
+
+    @property
+    def class_to_idx(self) -> Dict[str, int] | None:
+        """Returns a mapping of the class name to its target index."""
+
+    def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
+        """Returns the dataset metadata.
+
+        Args:
+            index: The index of the data sample to return the metadata of.
+                If `None`, it will return the metadata of the current dataset.
+
+        Returns:
+            The sample metadata.
+        """
+
+    @abc.abstractmethod
+    def load_image(self, index: int) -> np.ndarray:
+        """Returns the `index`'th image sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The image as a numpy array.
+        """
+
+    @abc.abstractmethod
+    def load_target(self, index: int) -> np.ndarray:
+        """Returns the `index`'th target sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample target as an array.
+        """
+
+    @abc.abstractmethod
+    @override
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+        image = self.load_image(index)
+        target = self.load_target(index)
+        return self._apply_transforms(image, target)
+
+    def _apply_transforms(
+        self, image: np.ndarray, target: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Applies the transforms to the provided data and returns them.
+
+        Args:
+            image: The desired image.
+            target: The target of the image.
+
+        Returns:
+            A tuple with the image and the target transformed.
+        """
+        if self._image_transforms is not None:
+            image = self._image_transforms(image)
+
+        if self._target_transforms is not None:
+            target = self._target_transforms(target)
+
+        return image, target
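A concrete subclass only has to provide `load_image`, `load_target`, and `__len__`; `__getitem__` and the transform plumbing come for free. A minimal in-memory sketch with synthetic arrays (it assumes the `VisionDataset` base, whose source is not shown here, leaves hooks such as `prepare_data`, `configure`, and `validate` as non-abstract defaults; `filename` is implemented since every concrete dataset above provides it):

from typing import List

import numpy as np
from typing_extensions import override

from eva.vision.data.datasets.classification import base


class RandomPatches(base.ImageClassification):
    """Synthetic dataset of random 8x8 RGB patches with binary targets."""

    def __init__(self, n: int = 16) -> None:
        super().__init__(image_transforms=lambda image: image / 255.0)
        self._images = np.random.randint(0, 256, size=(n, 8, 8, 3))
        self._targets = np.random.randint(0, 2, size=(n,))

    @property
    @override
    def classes(self) -> List[str]:
        return ["negative", "positive"]

    @override
    def filename(self, index: int) -> str:
        return f"patch_{index}.png"  # synthetic name; no files on disk

    @override
    def load_image(self, index: int) -> np.ndarray:
        return self._images[index]

    @override
    def load_target(self, index: int) -> np.ndarray:
        return np.asarray(self._targets[index], dtype=np.int64)

    @override
    def __len__(self) -> int:
        return len(self._images)


image, target = RandomPatches()[0]  # image is already scaled by the transform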
eva/vision/data/datasets/classification/crc.py
@@ -0,0 +1,176 @@
+"""CRC dataset class."""
+
+import os
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from torchvision.datasets import folder, utils
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, structs
+from eva.vision.data.datasets.classification import base
+from eva.vision.utils import io
+
+
+class CRC(base.ImageClassification):
+    """Dataset class for CRC images and corresponding targets."""
+
+    _train_resource: structs.DownloadResource = structs.DownloadResource(
+        filename="NCT-CRC-HE-100K.zip",
+        url="https://zenodo.org/records/1214456/files/NCT-CRC-HE-100K.zip?download=1",
+        md5="md5:035777cf327776a71a05c95da6d6325f",
+    )
+    """Train resource."""
+
+    _val_resource: structs.DownloadResource = structs.DownloadResource(
+        filename="CRC-VAL-HE-7K.zip",
+        url="https://zenodo.org/records/1214456/files/CRC-VAL-HE-7K.zip?download=1",
+        md5="md5:2fd1651b4f94ebd818ebf90ad2b6ce06",
+    )
+    """Validation resource."""
+
+    _license: str = "CC BY 4.0 LEGAL CODE (https://creativecommons.org/licenses/by/4.0/legalcode)"
+    """Dataset license."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val"],
+        download: bool = False,
+        image_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        The dataset is split into a train (train) and validation (val) set:
+        - train: A set of 100,000 non-overlapping image patches from
+          hematoxylin & eosin (H&E) stained histological images of human
+          colorectal cancer (CRC) and normal tissue.
+        - val: A set of 7,180 image patches from N=50 patients with colorectal
+          adenocarcinoma (no overlap with patients in NCT-CRC-HE-100K).
+
+        Args:
+            root: Path to the root directory of the dataset.
+            split: Dataset split to use.
+            download: Whether to download the data for the specified split.
+                Note that the download will be executed only by additionally
+                calling the :meth:`prepare_data` method and if the data does
+                not yet exist on disk.
+            image_transforms: A function/transform that takes in an image
+                and returns a transformed version.
+            target_transforms: A function/transform that takes in the target
+                and transforms it.
+        """
+        super().__init__(
+            image_transforms=image_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._root = root
+        self._split = split
+        self._download = download
+
+        self._samples: List[Tuple[str, int]] = []
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["ADI", "BACK", "DEB", "LYM", "MUC", "MUS", "NORM", "STR", "TUM"]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {
+            "ADI": 0,
+            "BACK": 1,
+            "DEB": 2,
+            "LYM": 3,
+            "MUC": 4,
+            "MUS": 5,
+            "NORM": 6,
+            "STR": 7,
+            "TUM": 8,
+        }
+
+    @override
+    def filename(self, index: int) -> str:
+        image_path, *_ = self._samples[index]
+        return os.path.relpath(image_path, self._dataset_dir)
+
+    @override
+    def prepare_data(self) -> None:
+        if self._download:
+            self._download_dataset()
+
+    @override
+    def configure(self) -> None:
+        self._samples = self._make_dataset()
+
+    @override
+    def validate(self) -> None:
+        expected_length = {
+            "train": 100000,
+            "val": 7180,
+            None: 107180,
+        }
+        _validators.check_dataset_integrity(
+            self,
+            length=expected_length.get(self._split, 0),
+            n_classes=9,
+            first_and_last_labels=("ADI", "TUM"),
+        )
+
+    @override
+    def load_image(self, index: int) -> np.ndarray:
+        image_path, _ = self._samples[index]
+        return io.read_image(image_path)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        _, target = self._samples[index]
+        return np.asarray(target, dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    @property
+    def _dataset_dir(self) -> str:
+        """Returns the full path of the dataset directory."""
+        dataset_dirs = {
+            "train": os.path.join(self._root, "NCT-CRC-HE-100K"),
+            "val": os.path.join(self._root, "CRC-VAL-HE-7K"),
+        }
+        dataset_dir = dataset_dirs.get(self._split)
+        if dataset_dir is None:
+            raise ValueError("Invalid data split. Use 'train' or 'val'.")
+
+        return dataset_dir
+
+    def _make_dataset(self) -> List[Tuple[str, int]]:
+        """Builds the dataset for the specified split."""
+        dataset = folder.make_dataset(
+            directory=self._dataset_dir,
+            class_to_idx=self.class_to_idx,
+            extensions=(".tif",),
+        )
+        return dataset
+
+    def _download_dataset(self) -> None:
+        """Downloads the dataset resources."""
+        for resource in [self._train_resource, self._val_resource]:
+            resource_dir = resource.filename.rsplit(".", maxsplit=1)[0]
+            if os.path.isdir(os.path.join(self._root, resource_dir)):
+                continue
+
+            self._print_license()
+            utils.download_and_extract_archive(
+                resource.url,
+                download_root=self._root,
+                filename=resource.filename,
+                remove_finished=True,
+            )
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
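Since `image_transforms` accepts any callable over numpy arrays, lightweight preprocessing can be attached without a transforms library. A usage sketch for the validation split, assuming the archives are already extracted under the hypothetical path `/data/crc`:

import numpy as np

from eva.vision.data.datasets.classification import CRC

dataset = CRC(
    root="/data/crc",
    split="val",
    image_transforms=lambda image: image.astype(np.float32) / 255.0,
)
dataset.configure()  # scans CRC-VAL-HE-7K into (path, target) samples
dataset.validate()   # expects the 7,180-patch validation split

image, target = dataset[0]
print(image.dtype, dataset.classes[int(target)])  # e.g. float32 ADI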
eva/vision/data/datasets/classification/mhist.py
@@ -0,0 +1,106 @@
+"""MHIST dataset class."""
+
+import os
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators
+from eva.vision.data.datasets.classification import base
+from eva.vision.utils import io
+
+
+class MHIST(base.ImageClassification):
+    """MHIST dataset."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "test"],
+        image_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Path to the root directory of the dataset.
+            split: Dataset split to use.
+            image_transforms: A function/transform that takes in an image
+                and returns a transformed version.
+            target_transforms: A function/transform that takes in the target
+                and transforms it.
+        """
+        super().__init__(
+            image_transforms=image_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._root = root
+        self._split = split
+
+        self._samples: List[Tuple[str, str]] = []
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["SSA", "HP"]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"SSA": 0, "HP": 1}
+
+    @override
+    def filename(self, index: int) -> str:
+        image_filename, _ = self._samples[index]
+        return image_filename
+
+    @override
+    def configure(self) -> None:
+        self._samples = self._make_dataset()
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=2175 if self._split == "train" else 977,
+            n_classes=2,
+            first_and_last_labels=("SSA", "HP"),
+        )
+
+    @override
+    def load_image(self, index: int) -> np.ndarray:
+        image_filename, _ = self._samples[index]
+        image_path = os.path.join(self._dataset_path, image_filename)
+        return io.read_image(image_path)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        _, label = self._samples[index]
+        target = self.class_to_idx[label]
+        return np.asarray(target, dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def _make_dataset(self) -> List[Tuple[str, str]]:
+        """Generates and returns a list of samples of the form (image_filename, label)."""
+        data = io.read_csv(self._annotations_path)
+        samples = [
+            (sample["Image Name"], sample["Majority Vote Label"])
+            for sample in data
+            if sample["Partition"] == self._split
+        ]
+        return samples
+
+    @property
+    def _dataset_path(self) -> str:
+        """Returns the path of the image data of the dataset."""
+        return os.path.join(self._root, "images")
+
+    @property
+    def _annotations_path(self) -> str:
+        """Returns the path of the annotations file of the dataset."""
+        return os.path.join(self._root, "annotations.csv")
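Unlike BACH and CRC, MHIST has no download logic: the images directory and `annotations.csv` must already exist under `root`. `_make_dataset` keys on three CSV columns, so a toy annotations file can drive the loader end to end. A sketch (the rows, the `/tmp/mhist-demo` path, and the assumption that `io.read_csv` yields one dict per row, as the key lookups above imply, are all illustrative):

import csv
import os

from eva.vision.data.datasets.classification import MHIST

root = "/tmp/mhist-demo"
os.makedirs(os.path.join(root, "images"), exist_ok=True)

# Write a two-row annotations.csv with the columns _make_dataset reads.
fieldnames = ["Image Name", "Partition", "Majority Vote Label"]
rows = [
    {"Image Name": "MHIST_aaa.png", "Partition": "train", "Majority Vote Label": "SSA"},
    {"Image Name": "MHIST_aab.png", "Partition": "test", "Majority Vote Label": "HP"},
]
with open(os.path.join(root, "annotations.csv"), "w", newline="") as file:
    writer = csv.DictWriter(file, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)

dataset = MHIST(root=root, split="train")
dataset.configure()
print(len(dataset), dataset.filename(0))  # 1 MHIST_aaa.png
# dataset.validate() would raise here: the real train split has 2175 images.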