kaiko_eva-0.0.0.dev6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (111)
  1. eva/.DS_Store +0 -0
  2. eva/__init__.py +33 -0
  3. eva/__main__.py +18 -0
  4. eva/__version__.py +25 -0
  5. eva/core/__init__.py +19 -0
  6. eva/core/callbacks/__init__.py +5 -0
  7. eva/core/callbacks/writers/__init__.py +5 -0
  8. eva/core/callbacks/writers/embeddings.py +169 -0
  9. eva/core/callbacks/writers/typings.py +23 -0
  10. eva/core/cli/__init__.py +5 -0
  11. eva/core/cli/cli.py +19 -0
  12. eva/core/cli/logo.py +38 -0
  13. eva/core/cli/setup.py +89 -0
  14. eva/core/data/__init__.py +14 -0
  15. eva/core/data/dataloaders/__init__.py +5 -0
  16. eva/core/data/dataloaders/dataloader.py +80 -0
  17. eva/core/data/datamodules/__init__.py +6 -0
  18. eva/core/data/datamodules/call.py +33 -0
  19. eva/core/data/datamodules/datamodule.py +108 -0
  20. eva/core/data/datamodules/schemas.py +62 -0
  21. eva/core/data/datasets/__init__.py +7 -0
  22. eva/core/data/datasets/base.py +53 -0
  23. eva/core/data/datasets/classification/__init__.py +5 -0
  24. eva/core/data/datasets/classification/embeddings.py +154 -0
  25. eva/core/data/datasets/dataset.py +6 -0
  26. eva/core/data/samplers/__init__.py +5 -0
  27. eva/core/data/samplers/sampler.py +6 -0
  28. eva/core/data/transforms/__init__.py +5 -0
  29. eva/core/data/transforms/dtype/__init__.py +5 -0
  30. eva/core/data/transforms/dtype/array.py +28 -0
  31. eva/core/interface/__init__.py +5 -0
  32. eva/core/interface/interface.py +79 -0
  33. eva/core/metrics/__init__.py +17 -0
  34. eva/core/metrics/average_loss.py +47 -0
  35. eva/core/metrics/binary_balanced_accuracy.py +22 -0
  36. eva/core/metrics/defaults/__init__.py +6 -0
  37. eva/core/metrics/defaults/classification/__init__.py +6 -0
  38. eva/core/metrics/defaults/classification/binary.py +76 -0
  39. eva/core/metrics/defaults/classification/multiclass.py +80 -0
  40. eva/core/metrics/structs/__init__.py +9 -0
  41. eva/core/metrics/structs/collection.py +6 -0
  42. eva/core/metrics/structs/metric.py +6 -0
  43. eva/core/metrics/structs/module.py +115 -0
  44. eva/core/metrics/structs/schemas.py +47 -0
  45. eva/core/metrics/structs/typings.py +15 -0
  46. eva/core/models/__init__.py +13 -0
  47. eva/core/models/modules/__init__.py +7 -0
  48. eva/core/models/modules/head.py +113 -0
  49. eva/core/models/modules/inference.py +37 -0
  50. eva/core/models/modules/module.py +190 -0
  51. eva/core/models/modules/typings.py +23 -0
  52. eva/core/models/modules/utils/__init__.py +6 -0
  53. eva/core/models/modules/utils/batch_postprocess.py +57 -0
  54. eva/core/models/modules/utils/grad.py +23 -0
  55. eva/core/models/networks/__init__.py +6 -0
  56. eva/core/models/networks/_utils.py +25 -0
  57. eva/core/models/networks/mlp.py +69 -0
  58. eva/core/models/networks/transforms/__init__.py +5 -0
  59. eva/core/models/networks/transforms/extract_cls_features.py +25 -0
  60. eva/core/models/networks/wrappers/__init__.py +8 -0
  61. eva/core/models/networks/wrappers/base.py +47 -0
  62. eva/core/models/networks/wrappers/from_function.py +58 -0
  63. eva/core/models/networks/wrappers/huggingface.py +37 -0
  64. eva/core/models/networks/wrappers/onnx.py +47 -0
  65. eva/core/trainers/__init__.py +6 -0
  66. eva/core/trainers/_logging.py +81 -0
  67. eva/core/trainers/_recorder.py +149 -0
  68. eva/core/trainers/_utils.py +12 -0
  69. eva/core/trainers/functional.py +113 -0
  70. eva/core/trainers/trainer.py +97 -0
  71. eva/core/utils/__init__.py +1 -0
  72. eva/core/utils/io/__init__.py +5 -0
  73. eva/core/utils/io/dataframe.py +21 -0
  74. eva/core/utils/multiprocessing.py +44 -0
  75. eva/core/utils/workers.py +21 -0
  76. eva/vision/__init__.py +14 -0
  77. eva/vision/data/__init__.py +5 -0
  78. eva/vision/data/datasets/__init__.py +22 -0
  79. eva/vision/data/datasets/_utils.py +50 -0
  80. eva/vision/data/datasets/_validators.py +44 -0
  81. eva/vision/data/datasets/classification/__init__.py +15 -0
  82. eva/vision/data/datasets/classification/bach.py +174 -0
  83. eva/vision/data/datasets/classification/base.py +103 -0
  84. eva/vision/data/datasets/classification/crc.py +176 -0
  85. eva/vision/data/datasets/classification/mhist.py +106 -0
  86. eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
  87. eva/vision/data/datasets/classification/total_segmentator.py +211 -0
  88. eva/vision/data/datasets/segmentation/__init__.py +6 -0
  89. eva/vision/data/datasets/segmentation/base.py +112 -0
  90. eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
  91. eva/vision/data/datasets/structs.py +17 -0
  92. eva/vision/data/datasets/vision.py +43 -0
  93. eva/vision/data/transforms/__init__.py +5 -0
  94. eva/vision/data/transforms/common/__init__.py +5 -0
  95. eva/vision/data/transforms/common/resize_and_crop.py +44 -0
  96. eva/vision/models/__init__.py +5 -0
  97. eva/vision/models/networks/__init__.py +6 -0
  98. eva/vision/models/networks/abmil.py +176 -0
  99. eva/vision/models/networks/postprocesses/__init__.py +5 -0
  100. eva/vision/models/networks/postprocesses/cls.py +25 -0
  101. eva/vision/utils/__init__.py +5 -0
  102. eva/vision/utils/io/__init__.py +12 -0
  103. eva/vision/utils/io/_utils.py +29 -0
  104. eva/vision/utils/io/image.py +54 -0
  105. eva/vision/utils/io/nifti.py +50 -0
  106. eva/vision/utils/io/text.py +18 -0
  107. kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
  108. kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
  109. kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
  110. kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
  111. kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
eva/vision/data/datasets/classification/patch_camelyon.py
@@ -0,0 +1,203 @@
+"""PatchCamelyon dataset."""
+
+import os
+from typing import Callable, Dict, List, Literal
+
+import h5py
+import numpy as np
+from torchvision.datasets import utils
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, structs
+from eva.vision.data.datasets.classification import base
+
+_URL_TEMPLATE = "https://zenodo.org/records/2546921/files/{filename}.gz?download=1"
+"""PatchCamelyon data files URL template."""
+
+
+class PatchCamelyon(base.ImageClassification):
+    """Dataset class for PatchCamelyon images and corresponding targets."""
+
+    _train_resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="camelyonpatch_level_2_split_train_x.h5",
+            url=_URL_TEMPLATE.format(filename="camelyonpatch_level_2_split_train_x.h5"),
+            md5="01844da899645b4d6f84946d417ba453",
+        ),
+        structs.DownloadResource(
+            filename="camelyonpatch_level_2_split_train_y.h5",
+            url=_URL_TEMPLATE.format(filename="camelyonpatch_level_2_split_train_y.h5"),
+            md5="0781386bf6c2fb62d58ff18891466aca",
+        ),
+    ]
+    """Train resources."""
+
+    _val_resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="camelyonpatch_level_2_split_valid_x.h5",
+            url=_URL_TEMPLATE.format(filename="camelyonpatch_level_2_split_valid_x.h5"),
+            md5="81cf9680f1724c40673f10dc88e909b1",
+        ),
+        structs.DownloadResource(
+            filename="camelyonpatch_level_2_split_valid_y.h5",
+            url=_URL_TEMPLATE.format(filename="camelyonpatch_level_2_split_valid_y.h5"),
+            md5="94d8aacc249253159ce2a2e78a86e658",
+        ),
+    ]
+    """Validation resources."""
+
+    _test_resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="camelyonpatch_level_2_split_test_x.h5",
+            url=_URL_TEMPLATE.format(filename="camelyonpatch_level_2_split_test_x.h5"),
+            md5="2614b2e6717d6356be141d9d6dbfcb7e",
+        ),
+        structs.DownloadResource(
+            filename="camelyonpatch_level_2_split_test_y.h5",
+            url=_URL_TEMPLATE.format(filename="camelyonpatch_level_2_split_test_y.h5"),
+            md5="11ed647efe9fe457a4eb45df1dba19ba",
+        ),
+    ]
+    """Test resources."""
+
+    _license: str = (
+        "Creative Commons Zero v1.0 Universal (https://choosealicense.com/licenses/cc0-1.0/)"
+    )
+    """Dataset license."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"],
+        download: bool = False,
+        image_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: The path to the dataset root. This path should contain
+                the uncompressed h5 files and the metadata.
+            split: The dataset split for training, validation, or testing.
+            download: Whether to download the data for the specified split.
+                Note that the download will be executed only by additionally
+                calling the :meth:`prepare_data` method.
+            image_transforms: A function/transform that takes in an image
+                and returns a transformed version.
+            target_transforms: A function/transform that takes in the target
+                and transforms it.
+        """
+        super().__init__(
+            image_transforms=image_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._root = root
+        self._split = split
+        self._download = download
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["no_tumor", "tumor"]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"no_tumor": 0, "tumor": 1}
+
+    @override
+    def filename(self, index: int) -> str:
+        return f"camelyonpatch_level_2_split_{self._split}_x_{index}"
+
+    @override
+    def prepare_data(self) -> None:
+        if self._download:
+            self._download_dataset()
+
+    @override
+    def validate(self) -> None:
+        expected_length = {
+            "train": 262144,
+            "val": 32768,
+            "test": 32768,
+        }
+        _validators.check_dataset_integrity(
+            self,
+            length=expected_length.get(self._split, 0),
+            n_classes=2,
+            first_and_last_labels=("no_tumor", "tumor"),
+        )
+
+    @override
+    def load_image(self, index: int) -> np.ndarray:
+        return self._load_from_h5("x", index)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        target = self._load_from_h5("y", index).squeeze()
+        return np.asarray(target, dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return self._fetch_dataset_length()
+
+    def _download_dataset(self) -> None:
+        """Downloads the PatchCamelyon dataset."""
+        for resource in self._train_resources + self._val_resources + self._test_resources:
+            file_path = os.path.join(self._root, resource.filename)
+            if utils.check_integrity(file_path, resource.md5):
+                continue
+
+            self._print_license()
+            utils.download_and_extract_archive(
+                resource.url,
+                download_root=self._root,
+                filename=resource.filename + ".gz",
+                remove_finished=True,
+            )
+
+    def _load_from_h5(
+        self,
+        data_key: Literal["x", "y"],
+        index: int | None = None,
+    ) -> np.ndarray:
+        """Load data or targets from an HDF5 file.
+
+        Args:
+            data_key: Specify whether to load 'x' or 'y'.
+            index: Optional parameter to load data/targets at a specific index.
+                If `None`, the entire data/targets array is returned.
+
+        Returns:
+            An array containing the specified data.
+        """
+        h5_file = self._h5_file(data_key)
+        with h5py.File(h5_file, "r") as file:
+            data = file[data_key]
+            return data[:] if index is None else data[index]  # type: ignore
+
+    def _fetch_dataset_length(self) -> int:
+        """Fetches the dataset split length from its HDF5 file."""
+        h5_file = self._h5_file("y")
+        with h5py.File(h5_file, "r") as file:
+            data = file["y"]
+            return len(data)  # type: ignore
+
+    def _h5_file(self, datatype: Literal["x", "y"]) -> str:
+        """Generates the filename for the H5 file based on the specified data type and split.
+
+        Args:
+            datatype: The type of data, where "x" and "y" represent the input
+                and target datasets respectively.
+
+        Returns:
+            The relative file path for the H5 file based on the provided data type and split.
+        """
+        split_suffix = "valid" if self._split == "val" else self._split
+        filename = f"camelyonpatch_level_2_split_{split_suffix}_{datatype}.h5"
+        return os.path.join(self._root, filename)
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
eva/vision/data/datasets/classification/total_segmentator.py
@@ -0,0 +1,211 @@
+"""TotalSegmentator 2D classification dataset class."""
+
+import functools
+import os
+from glob import glob
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from torchvision.datasets import utils
+from typing_extensions import override
+
+from eva.vision.data.datasets import _utils, _validators, structs
+from eva.vision.data.datasets.classification import base
+from eva.vision.utils import io
+
+
+class TotalSegmentatorClassification(base.ImageClassification):
+    """TotalSegmentator multi-label classification dataset."""
+
+    _train_index_ranges: List[Tuple[int, int]] = [(0, 83)]
+    """Train range indices."""
+
+    _val_index_ranges: List[Tuple[int, int]] = [(83, 103)]
+    """Validation range indices."""
+
+    _n_slices_per_image: int = 20
+    """The amount of slices to sample per 3D CT scan image."""
+
+    _resources_full: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="Totalsegmentator_dataset_v201.zip",
+            url="https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_v201.zip",
+            md5="fe250e5718e0a3b5df4c4ea9d58a62fe",
+        ),
+    ]
+    """Resources for the full dataset version."""
+
+    _resources_small: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="Totalsegmentator_dataset_small_v201.zip",
+            url="https://zenodo.org/records/10047263/files/Totalsegmentator_dataset_small_v201.zip",
+            md5="6b5524af4b15e6ba06ef2d700c0c73e0",
+        ),
+    ]
+    """Resources for the small dataset version."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val"] | None,
+        version: Literal["small", "full"] = "small",
+        download: bool = False,
+        image_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use. If None, the entire dataset is used.
+            version: The version of the dataset to initialize.
+            download: Whether to download the data for the specified split.
+                Note that the download will be executed only by additionally
+                calling the :meth:`prepare_data` method and if the data does not
+                exist yet on disk.
+            image_transforms: A function/transform that takes in an image
+                and returns a transformed version.
+            target_transforms: A function/transform that takes in the target
+                and transforms it.
+        """
+        super().__init__(
+            image_transforms=image_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._root = root
+        self._split = split
+        self._version = version
+        self._download = download
+
+        self._samples_dirs: List[str] = []
+        self._indices: List[int] = []
+
+    @functools.cached_property
+    @override
+    def classes(self) -> List[str]:
+        def get_filename(path: str) -> str:
+            """Returns the filename from the full path."""
+            return os.path.basename(path).split(".")[0]
+
+        first_sample_labels = os.path.join(
+            self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
+        )
+        return sorted(map(get_filename, glob(first_sample_labels)))
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {label: index for index, label in enumerate(self.classes)}
+
+    @override
+    def filename(self, index: int) -> str:
+        sample_dir = self._samples_dirs[self._indices[index]]
+        return os.path.join(sample_dir, "ct.nii.gz")
+
+    @override
+    def prepare_data(self) -> None:
+        if self._download:
+            self._download_dataset()
+
+    @override
+    def configure(self) -> None:
+        self._samples_dirs = self._fetch_samples_dirs()
+        self._indices = self._create_indices()
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=1660 if self._split == "train" else 400,
+            n_classes=117,
+            first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"),
+        )
+
+    @override
+    def __len__(self) -> int:
+        return len(self._indices) * self._n_slices_per_image
+
+    @override
+    def load_image(self, index: int) -> np.ndarray:
+        image_path = self._get_image_path(index)
+        slice_index = self._get_sample_slice_index(index)
+        image_array = io.read_nifti_slice(image_path, slice_index)
+        return image_array.repeat(3, axis=2)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        masks = self._load_masks(index)
+        targets = [1 in masks[..., mask_index] for mask_index in range(masks.shape[-1])]
+        return np.asarray(targets, dtype=np.int64)
+
+    def _load_masks(self, index: int) -> np.ndarray:
+        """Returns the `index`'th target mask sample."""
+        masks_dir = self._get_masks_dir(index)
+        slice_index = self._get_sample_slice_index(index)
+        mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
+        masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths]
+        return np.concatenate(masks, axis=-1)
+
+    def _get_masks_dir(self, index: int) -> str:
+        """Returns the directory of the corresponding masks."""
+        sample_dir = self._get_sample_dir(index)
+        return os.path.join(self._root, sample_dir, "segmentations")
+
+    def _get_image_path(self, index: int) -> str:
+        """Returns the corresponding image path."""
+        sample_dir = self._get_sample_dir(index)
+        return os.path.join(self._root, sample_dir, "ct.nii.gz")
+
+    def _get_sample_dir(self, index: int) -> str:
+        """Returns the corresponding sample directory."""
+        sample_index = self._indices[index // self._n_slices_per_image]
+        return self._samples_dirs[sample_index]
+
+    def _get_sample_slice_index(self, index: int) -> int:
+        """Returns the corresponding slice index."""
+        image_path = self._get_image_path(index)
+        total_slices = io.fetch_total_nifti_slices(image_path)
+        slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int)
+        return slice_indices[index % self._n_slices_per_image]
+
+    def _fetch_samples_dirs(self) -> List[str]:
+        """Returns the names of all the samples across all splits of the dataset."""
+        sample_filenames = [
+            filename
+            for filename in os.listdir(self._root)
+            if os.path.isdir(os.path.join(self._root, filename))
+        ]
+        return sorted(sample_filenames)
+
+    def _create_indices(self) -> List[int]:
+        """Builds the dataset indices for the specified split."""
+        split_index_ranges = {
+            "train": self._train_index_ranges,
+            "val": self._val_index_ranges,
+            None: [(0, 103)],
+        }
+        index_ranges = split_index_ranges.get(self._split)
+        if index_ranges is None:
+            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
+
+        return _utils.ranges_to_indices(index_ranges)
+
+    def _download_dataset(self) -> None:
+        """Downloads the dataset."""
+        dataset_resources = {
+            "small": self._resources_small,
+            "full": self._resources_full,
+        }
+        resources = dataset_resources.get(self._version)
+        if resources is None:
+            raise ValueError("Invalid data version. Use 'small' or 'full'.")
+
+        for resource in resources:
+            utils.download_and_extract_archive(
+                resource.url,
+                download_root=self._root,
+                filename=resource.filename,
+                remove_finished=True,
+            )
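
The flat-index arithmetic above maps dataset index i to volume i // 20 and to one of 20 evenly spaced slices within it; a small standalone sketch of that computation (the 300-slice volume depth is an assumed value):

# Standalone sketch of the slice sampling; 300 is a hypothetical volume depth.
import numpy as np

n_slices_per_image = 20
total_slices = 300  # assumed number of slices in one CT volume

index = 45  # flat dataset index
volume_position = index // n_slices_per_image  # -> 2, i.e. the third volume
slice_position = index % n_slices_per_image    # -> 5
slice_indices = np.linspace(0, total_slices - 1, num=n_slices_per_image, dtype=int)
print(slice_indices[slice_position])  # -> 78, the slice passed to read_nifti_slice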
eva/vision/data/datasets/segmentation/__init__.py
@@ -0,0 +1,6 @@
+"""Segmentation datasets API."""
+
+from eva.vision.data.datasets.segmentation.base import ImageSegmentation
+from eva.vision.data.datasets.segmentation.total_segmentator import TotalSegmentator2D
+
+__all__ = ["ImageSegmentation", "TotalSegmentator2D"]
eva/vision/data/datasets/segmentation/base.py
@@ -0,0 +1,112 @@
+"""Base for image segmentation datasets."""
+
+import abc
+from typing import Any, Callable, Dict, List, Tuple
+
+import numpy as np
+from typing_extensions import override
+
+from eva.vision.data.datasets import vision
+
+
+class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+    """Image segmentation abstract dataset."""
+
+    def __init__(
+        self,
+        image_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+        image_target_transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the image segmentation base class.
+
+        Args:
+            image_transforms: A function/transform that takes in an image
+                and returns a transformed version.
+            target_transforms: A function/transform that takes in the target
+                and transforms it.
+            image_target_transforms: A function/transform that takes in an
+                image and a label and returns the transformed versions of both.
+                This transform happens after the `image_transforms` and
+                `target_transforms`.
+        """
+        super().__init__()
+
+        self._image_transforms = image_transforms
+        self._target_transforms = target_transforms
+        self._image_target_transforms = image_target_transforms
+
+    @property
+    def classes(self) -> List[str] | None:
+        """Returns the list of the dataset class names."""
+
+    @property
+    def class_to_idx(self) -> Dict[str, int] | None:
+        """Returns a mapping of the class name to its target index."""
+
+    def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
+        """Returns the dataset metadata.
+
+        Args:
+            index: The index of the data sample to return the metadata of.
+                If `None`, it will return the metadata of the current dataset.
+
+        Returns:
+            The sample metadata.
+        """
+
+    @abc.abstractmethod
+    def load_image(self, index: int) -> np.ndarray:
+        """Loads and returns the `index`'th image sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The image as a numpy array.
+        """
+
+    @abc.abstractmethod
+    def load_mask(self, index: int) -> np.ndarray:
+        """Returns the `index`'th target mask sample.
+
+        Args:
+            index: The index of the data sample target mask to load.
+
+        Returns:
+            The sample mask as a stack of binary mask arrays (label, height, width).
+        """
+
+    @abc.abstractmethod
+    @override
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
+        image = self.load_image(index)
+        mask = self.load_mask(index)
+        return self._apply_transforms(image, mask)
+
+    def _apply_transforms(
+        self, image: np.ndarray, target: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Applies the transforms to the provided data and returns them.
+
+        Args:
+            image: The desired image.
+            target: The target of the image.
+
+        Returns:
+            A tuple with the image and the target transformed.
+        """
+        if self._image_transforms is not None:
+            image = self._image_transforms(image)
+
+        if self._target_transforms is not None:
+            target = self._target_transforms(target)
+
+        if self._image_target_transforms is not None:
+            image, target = self._image_target_transforms(image, target)
+
+        return image, target
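
To illustrate the contract this base class defines, a toy concrete subclass (random data, illustration only) needs to provide load_image, load_mask and __len__; __getitem__ and the transform plumbing are inherited:

# Toy subclass sketch; the data is random and purely illustrative.
import numpy as np
from eva.vision.data.datasets.segmentation import ImageSegmentation

class ToySegmentation(ImageSegmentation):
    """Yields random RGB images with single-label binary masks."""

    def load_image(self, index: int) -> np.ndarray:
        return np.random.rand(224, 224, 3).astype(np.float32)

    def load_mask(self, index: int) -> np.ndarray:
        # One binary mask per label, stacked as (label, height, width).
        return (np.random.rand(1, 224, 224) > 0.5).astype(np.int64)

    def filename(self, index: int) -> str:
        # Provided in case the VisionDataset base declares it abstract (assumption).
        return f"toy_{index}"

    def __len__(self) -> int:
        return 8

image, mask = ToySegmentation()[0]  # transforms (none here) applied by the base class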