kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva has been flagged as potentially problematic by the registry scanner.

Files changed (159)
  1. eva/core/callbacks/__init__.py +2 -2
  2. eva/core/callbacks/writers/__init__.py +6 -3
  3. eva/core/callbacks/writers/embeddings/__init__.py +6 -0
  4. eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
  5. eva/core/callbacks/writers/embeddings/base.py +192 -0
  6. eva/core/callbacks/writers/embeddings/classification.py +117 -0
  7. eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
  8. eva/core/callbacks/writers/embeddings/typings.py +38 -0
  9. eva/core/data/datasets/__init__.py +2 -2
  10. eva/core/data/datasets/classification/__init__.py +8 -0
  11. eva/core/data/datasets/classification/embeddings.py +34 -0
  12. eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
  13. eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
  14. eva/core/data/splitting/__init__.py +6 -0
  15. eva/core/data/splitting/random.py +41 -0
  16. eva/core/data/splitting/stratified.py +56 -0
  17. eva/core/loggers/experimental_loggers.py +2 -2
  18. eva/core/loggers/log/__init__.py +3 -2
  19. eva/core/loggers/log/image.py +71 -0
  20. eva/core/loggers/log/parameters.py +10 -0
  21. eva/core/loggers/loggers.py +6 -0
  22. eva/core/metrics/__init__.py +6 -2
  23. eva/core/metrics/defaults/__init__.py +10 -3
  24. eva/core/metrics/defaults/classification/__init__.py +1 -1
  25. eva/core/metrics/defaults/classification/binary.py +0 -9
  26. eva/core/metrics/defaults/classification/multiclass.py +0 -8
  27. eva/core/metrics/defaults/segmentation/__init__.py +5 -0
  28. eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
  29. eva/core/metrics/generalized_dice.py +59 -0
  30. eva/core/metrics/mean_iou.py +120 -0
  31. eva/core/metrics/structs/schemas.py +3 -1
  32. eva/core/models/__init__.py +3 -1
  33. eva/core/models/modules/head.py +10 -4
  34. eva/core/models/modules/typings.py +14 -1
  35. eva/core/models/modules/utils/batch_postprocess.py +37 -5
  36. eva/core/models/networks/__init__.py +1 -2
  37. eva/core/models/networks/mlp.py +2 -2
  38. eva/core/models/transforms/__init__.py +6 -0
  39. eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
  40. eva/core/models/transforms/extract_patch_features.py +47 -0
  41. eva/core/models/wrappers/__init__.py +13 -0
  42. eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
  43. eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
  44. eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
  45. eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
  46. eva/core/trainers/functional.py +1 -0
  47. eva/core/utils/__init__.py +6 -0
  48. eva/core/utils/clone.py +27 -0
  49. eva/core/utils/memory.py +28 -0
  50. eva/core/utils/operations.py +26 -0
  51. eva/core/utils/parser.py +20 -0
  52. eva/vision/__init__.py +2 -2
  53. eva/vision/callbacks/__init__.py +5 -0
  54. eva/vision/callbacks/loggers/__init__.py +5 -0
  55. eva/vision/callbacks/loggers/batch/__init__.py +5 -0
  56. eva/vision/callbacks/loggers/batch/base.py +130 -0
  57. eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
  58. eva/vision/data/datasets/__init__.py +30 -3
  59. eva/vision/data/datasets/_validators.py +15 -2
  60. eva/vision/data/datasets/classification/__init__.py +12 -1
  61. eva/vision/data/datasets/classification/bach.py +10 -15
  62. eva/vision/data/datasets/classification/base.py +17 -24
  63. eva/vision/data/datasets/classification/camelyon16.py +244 -0
  64. eva/vision/data/datasets/classification/crc.py +10 -15
  65. eva/vision/data/datasets/classification/mhist.py +10 -15
  66. eva/vision/data/datasets/classification/panda.py +184 -0
  67. eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
  68. eva/vision/data/datasets/classification/wsi.py +105 -0
  69. eva/vision/data/datasets/segmentation/__init__.py +15 -2
  70. eva/vision/data/datasets/segmentation/_utils.py +38 -0
  71. eva/vision/data/datasets/segmentation/base.py +16 -17
  72. eva/vision/data/datasets/segmentation/bcss.py +236 -0
  73. eva/vision/data/datasets/segmentation/consep.py +156 -0
  74. eva/vision/data/datasets/segmentation/embeddings.py +34 -0
  75. eva/vision/data/datasets/segmentation/lits.py +178 -0
  76. eva/vision/data/datasets/segmentation/monusac.py +236 -0
  77. eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
  78. eva/vision/data/datasets/wsi.py +187 -0
  79. eva/vision/data/transforms/__init__.py +3 -2
  80. eva/vision/data/transforms/common/__init__.py +2 -1
  81. eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
  82. eva/vision/data/transforms/common/resize_and_crop.py +6 -7
  83. eva/vision/data/transforms/normalization/__init__.py +6 -0
  84. eva/vision/data/transforms/normalization/clamp.py +43 -0
  85. eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
  86. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
  87. eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
  88. eva/vision/data/wsi/__init__.py +16 -0
  89. eva/vision/data/wsi/backends/__init__.py +69 -0
  90. eva/vision/data/wsi/backends/base.py +115 -0
  91. eva/vision/data/wsi/backends/openslide.py +73 -0
  92. eva/vision/data/wsi/backends/pil.py +52 -0
  93. eva/vision/data/wsi/backends/tiffslide.py +42 -0
  94. eva/vision/data/wsi/patching/__init__.py +6 -0
  95. eva/vision/data/wsi/patching/coordinates.py +98 -0
  96. eva/vision/data/wsi/patching/mask.py +123 -0
  97. eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
  98. eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
  99. eva/vision/data/wsi/patching/samplers/base.py +48 -0
  100. eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
  101. eva/vision/data/wsi/patching/samplers/grid.py +47 -0
  102. eva/vision/data/wsi/patching/samplers/random.py +41 -0
  103. eva/vision/losses/__init__.py +5 -0
  104. eva/vision/losses/dice.py +40 -0
  105. eva/vision/models/__init__.py +4 -2
  106. eva/vision/models/modules/__init__.py +5 -0
  107. eva/vision/models/modules/semantic_segmentation.py +161 -0
  108. eva/vision/models/networks/__init__.py +1 -2
  109. eva/vision/models/networks/backbones/__init__.py +6 -0
  110. eva/vision/models/networks/backbones/_utils.py +39 -0
  111. eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
  112. eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
  113. eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
  114. eva/vision/models/networks/backbones/pathology/histai.py +46 -0
  115. eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
  116. eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
  117. eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
  118. eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
  119. eva/vision/models/networks/backbones/registry.py +47 -0
  120. eva/vision/models/networks/backbones/timm/__init__.py +5 -0
  121. eva/vision/models/networks/backbones/timm/backbones.py +54 -0
  122. eva/vision/models/networks/backbones/universal/__init__.py +8 -0
  123. eva/vision/models/networks/backbones/universal/vit.py +54 -0
  124. eva/vision/models/networks/decoders/__init__.py +6 -0
  125. eva/vision/models/networks/decoders/decoder.py +7 -0
  126. eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
  127. eva/vision/models/networks/decoders/segmentation/common.py +74 -0
  128. eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
  129. eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
  130. eva/vision/models/wrappers/__init__.py +6 -0
  131. eva/vision/models/wrappers/from_registry.py +48 -0
  132. eva/vision/models/wrappers/from_timm.py +68 -0
  133. eva/vision/utils/colormap.py +77 -0
  134. eva/vision/utils/convert.py +56 -13
  135. eva/vision/utils/io/__init__.py +10 -4
  136. eva/vision/utils/io/image.py +21 -2
  137. eva/vision/utils/io/mat.py +36 -0
  138. eva/vision/utils/io/nifti.py +33 -12
  139. eva/vision/utils/io/text.py +10 -3
  140. kaiko_eva-0.1.1.dist-info/METADATA +553 -0
  141. kaiko_eva-0.1.1.dist-info/RECORD +205 -0
  142. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
  143. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
  144. eva/.DS_Store +0 -0
  145. eva/core/callbacks/writers/embeddings.py +0 -169
  146. eva/core/callbacks/writers/typings.py +0 -23
  147. eva/core/data/datasets/embeddings/__init__.py +0 -13
  148. eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
  149. eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
  150. eva/core/models/networks/transforms/__init__.py +0 -5
  151. eva/core/models/networks/wrappers/__init__.py +0 -8
  152. eva/vision/models/.DS_Store +0 -0
  153. eva/vision/models/networks/.DS_Store +0 -0
  154. eva/vision/models/networks/postprocesses/__init__.py +0 -5
  155. eva/vision/models/networks/postprocesses/cls.py +0 -25
  156. kaiko_eva-0.0.2.dist-info/METADATA +0 -431
  157. kaiko_eva-0.0.2.dist-info/RECORD +0 -127
  158. /eva/core/models/{networks → wrappers}/_utils.py +0 -0
  159. {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/lits.py (new file)
@@ -0,0 +1,178 @@
+"""LiTS dataset."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import torch
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.core import utils
+from eva.vision.data.datasets import _utils as data_utils
+from eva.vision.data.datasets import _validators
+from eva.vision.data.datasets.segmentation import base
+from eva.vision.utils import io
+
+
+class LiTS(base.ImageSegmentation):
+    """LiTS - Liver Tumor Segmentation Challenge.
+
+    Webpage: https://competitions.codalab.org/competitions/17094
+
+    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
+    """
+
+    _train_index_ranges: List[Tuple[int, int]] = [(0, 102)]
+    _val_index_ranges: List[Tuple[int, int]] = [(102, 117)]
+    _test_index_ranges: List[Tuple[int, int]] = [(117, 131)]
+    """Index ranges per split."""
+
+    _sample_every_n_slices: int | None = None
+    """The amount of slices to sub-sample per 3D CT scan image."""
+
+    _expected_dataset_lengths: Dict[str | None, int] = {
+        "train": 39307,
+        "val": 12045,
+        "test": 7286,
+        None: 58638,
+    }
+    """Dataset version and split to the expected size."""
+
+    _license: str = (
+        "Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License "
+        "(https://creativecommons.org/licenses/by-nc-nd/4.0/deed.en)"
+    )
+    """Dataset license."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"] | None = None,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use.
+            transforms: A function/transforms that takes in an image and a target
+                mask and returns the transformed versions of both.
+        """
+        super().__init__(transforms=transforms)
+
+        self._root = root
+        self._split = split
+
+        self._indices: List[Tuple[int, int]] = []
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["liver", "tumor"]
+
+    @functools.cached_property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {label: index for index, label in enumerate(self.classes)}
+
+    @override
+    def filename(self, index: int) -> str:
+        sample_index, _ = self._indices[index]
+        volume_file_path = self._volume_files[sample_index]
+        return os.path.relpath(volume_file_path, self._root)
+
+    @override
+    def configure(self) -> None:
+        self._indices = self._create_indices()
+
+    @override
+    def validate(self) -> None:
+        if len(self._volume_files) != len(self._segmentation_files):
+            raise ValueError(
+                "The number of volume files does not match the number of the segmentation ones."
+            )
+
+        _validators.check_dataset_integrity(
+            self,
+            length=self._expected_dataset_lengths.get(self._split, 0),
+            n_classes=2,
+            first_and_last_labels=("liver", "tumor"),
+        )
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        sample_index, slice_index = self._indices[index]
+        volume_path = self._volume_files[sample_index]
+        image_array = io.read_nifti(volume_path, slice_index)
+        return tv_tensors.Image(image_array.transpose(2, 0, 1))
+
+    @override
+    def load_mask(self, index: int) -> tv_tensors.Mask:
+        sample_index, slice_index = self._indices[index]
+        segmentation_path = self._segmentation_files[sample_index]
+        semantic_labels = io.read_nifti(segmentation_path, slice_index)
+        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        _, slice_index = self._indices[index]
+        return {"slice_index": slice_index}
+
+    @override
+    def __len__(self) -> int:
+        return len(self._indices)
+
+    def _get_number_of_slices_per_volume(self, sample_index: int) -> int:
+        """Returns the total amount of slices of a volume."""
+        file_path = self._volume_files[sample_index]
+        volume_shape = io.fetch_nifti_shape(file_path)
+        return volume_shape[-1]
+
+    @functools.cached_property
+    def _volume_files(self) -> List[str]:
+        files_pattern = os.path.join(self._root, "**", "volume-*.nii")
+        files = glob.glob(files_pattern, recursive=True)
+        return utils.numeric_sort(files)
+
+    @functools.cached_property
+    def _segmentation_files(self) -> List[str]:
+        files_pattern = os.path.join(self._root, "**", "segmentation-*.nii")
+        files = glob.glob(files_pattern, recursive=True)
+        return utils.numeric_sort(files)
+
+    def _create_indices(self) -> List[Tuple[int, int]]:
+        """Builds the dataset indices for the specified split.
+
+        Returns:
+            A list of tuples, where the first value indicates the
+            sample index and the second its corresponding slice
+            index.
+        """
+        indices = [
+            (sample_idx, slide_idx)
+            for sample_idx in self._get_split_indices()
+            for slide_idx in range(self._get_number_of_slices_per_volume(sample_idx))
+            if slide_idx % (self._sample_every_n_slices or 1) == 0
+        ]
+        return indices
+
+    def _get_split_indices(self) -> List[int]:
+        """Returns the sample indices for the specified dataset split."""
+        split_index_ranges = {
+            "train": self._train_index_ranges,
+            "val": self._val_index_ranges,
+            "test": self._test_index_ranges,
+            None: [(0, len(self._volume_files))],
+        }
+        index_ranges = split_index_ranges.get(self._split)
+        if index_ranges is None:
+            raise ValueError("Invalid data split. Use 'train', 'val', 'test' or `None`.")
+
+        return data_utils.ranges_to_indices(index_ranges)
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
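
A minimal sketch of how the new LiTS dataset class might be used on its own, assuming the LiTS NIfTI files (volume-*.nii and segmentation-*.nii) are already extracted under a local root directory (the path below is a placeholder) and that the configure/validate hooks are invoked manually, whereas in a full eva run they would normally be driven by the data module:

    from eva.vision.data.datasets.segmentation import lits

    # Placeholder path; point this at a local copy of the LiTS data.
    dataset = lits.LiTS(root="/data/lits", split="train")
    dataset.configure()  # builds the (volume index, slice index) pairs
    dataset.validate()   # checks the expected split size (39307 train slices)

    image = dataset.load_image(0)  # tv_tensors.Image, a CHW slice of a CT volume
    mask = dataset.load_mask(0)    # tv_tensors.Mask with per-pixel semantic labels
    print(len(dataset), image.shape, mask.shape)
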
eva/vision/data/datasets/segmentation/monusac.py (new file)
@@ -0,0 +1,236 @@
+"""MoNuSAC dataset."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal
+from xml.etree import ElementTree  # nosec
+
+import imagesize
+import numpy as np
+import numpy.typing as npt
+import torch
+import tqdm
+from skimage import draw
+from torchvision import tv_tensors
+from torchvision.datasets import utils
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, structs
+from eva.vision.data.datasets.segmentation import base
+from eva.vision.utils import io
+
+
+class MoNuSAC(base.ImageSegmentation):
+    """MoNuSAC2020: A Multi-organ Nuclei Segmentation and Classification Challenge.
+
+    Webpage: https://monusac-2020.grand-challenge.org/
+    """
+
+    _expected_dataset_lengths: Dict[str, int] = {
+        "train": 209,
+        "test": 85,
+    }
+    """Dataset version and split to the expected size."""
+
+    _resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="MoNuSAC_images_and_annotations.zip",
+            url="https://drive.google.com/file/d/1lxMZaAPSpEHLSxGA9KKMt_r-4S8dwLhq/view?usp=sharing",
+        ),
+        structs.DownloadResource(
+            filename="MoNuSAC Testing Data and Annotations.zip",
+            url="https://drive.google.com/file/d/1G54vsOdxWY1hG7dzmkeK3r0xz9s-heyQ/view?usp=sharing",
+        ),
+    ]
+    """Resources for the full dataset version."""
+
+    _license: str = (
+        "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International "
+        "(https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode)"
+    )
+    """Dataset license."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "test"],
+        export_masks: bool = True,
+        download: bool = False,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use.
+            export_masks: Whether to export, save and use the semantic label masks
+                from disk.
+            download: Whether to download the data for the specified split.
+                Note that the download will be executed only by additionally
+                calling the :meth:`prepare_data` method and if the data does not
+                exist yet on disk.
+            transforms: A function/transforms that takes in an image and a target
+                mask and returns the transformed versions of both.
+        """
+        super().__init__(transforms=transforms)
+
+        self._root = root
+        self._split = split
+        self._export_masks = export_masks
+        self._download = download
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["Epithelial", "Lymphocyte", "Neutrophil", "Macrophage"]
+
+    @functools.cached_property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {label: index for index, label in enumerate(self.classes)}
+
+    @override
+    def filename(self, index: int) -> str:
+        return os.path.relpath(self._image_files[index], self._root)
+
+    @override
+    def prepare_data(self) -> None:
+        if self._download:
+            self._download_dataset()
+        if self._export_masks:
+            self._export_semantic_label_masks()
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=self._expected_dataset_lengths.get(self._split, 0),
+            n_classes=4,
+            first_and_last_labels=("Epithelial", "Macrophage"),
+        )
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_path = self._image_files[index]
+        image_rgb_array = io.read_image(image_path)
+        return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
+
+    @override
+    def load_mask(self, index: int) -> tv_tensors.Mask:
+        semantic_labels = (
+            self._load_semantic_mask_file(index)
+            if self._export_masks
+            else self._get_semantic_mask(index)
+        )
+        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+
+    @override
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    @functools.cached_property
+    def _image_files(self) -> List[str]:
+        """Return the list of image files in the dataset.
+
+        Returns:
+            List of image file paths.
+        """
+        files_pattern = os.path.join(self._data_directory, "**", "*.tif")
+        image_files = glob.glob(files_pattern, recursive=True)
+        return sorted(image_files)
+
+    @functools.cached_property
+    def _data_directory(self) -> str:
+        """Returns the data directory of the dataset."""
+        match self._split:
+            case "train":
+                directory = "MoNuSAC_images_and_annotations"
+            case "test":
+                directory = "MoNuSAC Testing Data and Annotations"
+            case _:
+                raise ValueError(f"Invalid 'split' value '{self._split}'.")
+
+        return os.path.join(self._root, directory)
+
+    def _export_semantic_label_masks(self) -> None:
+        """Export semantic label masks to disk."""
+        mask_files = [
+            (index, filename.replace(".tif", ".npy"))
+            for index, filename in enumerate(self._image_files)
+        ]
+        to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
+        for sample_index, filename in tqdm.tqdm(
+            list(to_export),
+            desc=">> Exporting semantic masks",
+            leave=False,
+        ):
+            semantic_labels = self._get_semantic_mask(sample_index)
+            np.save(filename, semantic_labels)
+
+    def _load_semantic_mask_file(self, index: int) -> npt.NDArray[Any]:
+        """Load a semantic mask file from disk.
+
+        Args:
+            index: Index of the mask file to load.
+
+        Returns:
+            Loaded mask as a numpy array.
+        """
+        mask_filename = self._image_files[index].replace(".tif", ".npy")
+        return np.load(mask_filename)
+
+    def _get_semantic_mask(self, index: int) -> npt.NDArray[Any]:
+        """Builds and loads the semantic label mask from the XML annotations.
+
+        Args:
+            index: Index of the image file.
+
+        Returns:
+            Semantic label mask as a numpy array.
+        """
+        image_path = self._image_files[index]
+        width, height = imagesize.get(image_path)
+        annotation_path = image_path.replace(".tif", ".xml")
+        element_tree = ElementTree.parse(annotation_path)  # nosec
+        root = element_tree.getroot()
+
+        semantic_labels = np.zeros((height, width), "uint8")  # type: ignore[reportCallIssue]
+        for level in range(len(root)):
+            label = [item.attrib["Name"] for item in root[level][0]][0]
+            class_id = self.class_to_idx.get(label, 254) + 1
+            # for the test dataset an additional class 'Ambiguous' was added for
+            # difficult regions with fuzzy boundaries - we return it as 255
+            regions = [item for child in root[level] for item in child if item.tag == "Region"]
+            for region in regions:
+                vertices = np.array(
+                    [(vertex.attrib["X"], vertex.attrib["Y"]) for vertex in region[1]],
+                    dtype=np.dtype(float),
+                )
+                fill_row_coords, fill_col_coords = draw.polygon(
+                    vertices[:, 0],
+                    vertices[:, 1],
+                    (width, height),
+                )
+                semantic_labels[fill_col_coords, fill_row_coords] = class_id
+
+        return semantic_labels
+
+    def _download_dataset(self) -> None:
+        """Downloads the dataset."""
+        self._print_license()
+        for resource in self._resources:
+            if os.path.isdir(self._data_directory):
+                continue
+
+            utils.download_and_extract_archive(
+                resource.url,
+                download_root=self._root,
+                filename=resource.filename,
+                remove_finished=True,
+            )
+
+    def _print_license(self) -> None:
+        """Prints the dataset license."""
+        print(f"Dataset license: {self._license}")
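
Similarly, a minimal sketch of how the new MoNuSAC dataset class might be exercised on its own, assuming a local copy of the archives under a placeholder root path (or download=True to fetch the Google Drive resources listed in _resources) and manual invocation of the prepare_data/validate hooks, which an eva data module would normally handle:

    from eva.vision.data.datasets.segmentation import monusac

    dataset = monusac.MoNuSAC(
        root="/data/monusac",  # placeholder path
        split="train",
        export_masks=True,     # pre-compute .npy masks from the XML annotations
        download=False,        # set to True to fetch the archives listed in _resources
    )
    dataset.prepare_data()     # downloads (if requested) and exports the semantic masks
    dataset.validate()         # expects 209 training images

    image = dataset.load_image(0)  # tv_tensors.Image (C, H, W)
    mask = dataset.load_mask(0)    # tv_tensors.Mask, 0 = background, 1-4 = nuclei classes
    print(dataset.filename(0), image.shape, mask.shape)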