kaiko-eva 0.1.1__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- eva/core/callbacks/writers/embeddings/base.py +3 -4
- eva/core/data/dataloaders/dataloader.py +2 -2
- eva/core/data/splitting/random.py +6 -5
- eva/core/data/splitting/stratified.py +12 -6
- eva/core/losses/__init__.py +5 -0
- eva/core/losses/cross_entropy.py +27 -0
- eva/core/metrics/__init__.py +0 -4
- eva/core/metrics/defaults/__init__.py +0 -2
- eva/core/models/modules/module.py +9 -9
- eva/core/models/transforms/extract_cls_features.py +17 -9
- eva/core/models/transforms/extract_patch_features.py +23 -11
- eva/core/utils/io/__init__.py +2 -1
- eva/core/utils/io/gz.py +28 -0
- eva/core/utils/multiprocessing.py +46 -1
- eva/core/utils/progress_bar.py +15 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +7 -4
- eva/vision/data/datasets/__init__.py +4 -0
- eva/vision/data/datasets/classification/__init__.py +2 -1
- eva/vision/data/datasets/classification/camelyon16.py +4 -1
- eva/vision/data/datasets/classification/panda.py +17 -1
- eva/vision/data/datasets/classification/wsi.py +4 -1
- eva/vision/data/datasets/segmentation/__init__.py +2 -0
- eva/vision/data/datasets/segmentation/consep.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +49 -29
- eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
- eva/vision/data/datasets/segmentation/monusac.py +7 -7
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +50 -18
- eva/vision/data/datasets/wsi.py +37 -1
- eva/vision/data/wsi/patching/coordinates.py +9 -1
- eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
- eva/vision/data/wsi/patching/samplers/random.py +4 -2
- eva/vision/losses/__init__.py +2 -2
- eva/vision/losses/dice.py +75 -8
- eva/vision/metrics/__init__.py +11 -0
- eva/vision/metrics/defaults/__init__.py +7 -0
- eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
- eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
- eva/vision/metrics/segmentation/BUILD +1 -0
- eva/vision/metrics/segmentation/__init__.py +9 -0
- eva/vision/metrics/segmentation/_utils.py +69 -0
- eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
- eva/vision/metrics/segmentation/mean_iou.py +57 -0
- eva/vision/models/modules/semantic_segmentation.py +4 -3
- eva/vision/models/networks/backbones/_utils.py +12 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
- eva/vision/models/networks/backbones/pathology/histai.py +8 -2
- eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
- eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
- eva/vision/models/networks/backbones/pathology/paige.py +51 -0
- eva/vision/models/networks/decoders/__init__.py +1 -1
- eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
- eva/vision/models/networks/decoders/segmentation/base.py +16 -0
- eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
- eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
- eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
- eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
- eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
- eva/vision/utils/colormap.py +20 -0
- eva/vision/utils/io/__init__.py +7 -1
- eva/vision/utils/io/nifti.py +19 -4
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/METADATA +8 -39
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/RECORD +66 -52
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/WHEEL +1 -1
- eva/core/metrics/mean_iou.py +0 -120
- eva/vision/models/networks/decoders/decoder.py +0 -7
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/lits.py
CHANGED

@@ -5,12 +5,14 @@ import glob
 import os
 from typing import Any, Callable, Dict, List, Literal, Tuple
 
+import numpy as np
+import numpy.typing as npt
 import torch
 from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.core import utils
-from eva.
+from eva.core.data import splitting
 from eva.vision.data.datasets import _validators
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -20,22 +22,23 @@ class LiTS(base.ImageSegmentation):
     """LiTS - Liver Tumor Segmentation Challenge.
 
     Webpage: https://competitions.codalab.org/competitions/17094
-
-    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
     """
 
-
-
-
+    _train_ratio: float = 0.7
+    _val_ratio: float = 0.15
+    _test_ratio: float = 0.15
     """Index ranges per split."""
 
+    _fix_orientation: bool = True
+    """Whether to fix the orientation of the images to match the default for radiologists."""
+
     _sample_every_n_slices: int | None = None
     """The amount of slices to sub-sample per 3D CT scan image."""
 
     _expected_dataset_lengths: Dict[str | None, int] = {
-        "train":
-        "val":
-        "test":
+        "train": 38686,
+        "val": 11192,
+        "test": 8760,
         None: 58638,
     }
     """Dataset version and split to the expected size."""
@@ -51,6 +54,7 @@ class LiTS(base.ImageSegmentation):
         root: str,
         split: Literal["train", "val", "test"] | None = None,
         transforms: Callable | None = None,
+        seed: int = 8,
     ) -> None:
         """Initialize dataset.
 
@@ -60,12 +64,13 @@ class LiTS(base.ImageSegmentation):
             split: Dataset split to use.
             transforms: A function/transforms that takes in an image and a target
                 mask and returns the transformed versions of both.
+            seed: Seed used for generating the dataset splits.
         """
         super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
-
+        self._seed = seed
         self._indices: List[Tuple[int, int]] = []
 
     @property
@@ -90,10 +95,12 @@ class LiTS(base.ImageSegmentation):
 
     @override
     def validate(self) -> None:
-
-
-
-
+        for i in range(len(self._volume_files)):
+            seg_path = self._segmentation_file(i)
+            if not os.path.exists(seg_path):
+                raise FileNotFoundError(
+                    f"Segmentation file {seg_path} not found for volume {self._volume_files[i]}."
+                )
 
         _validators.check_dataset_integrity(
             self,
@@ -107,15 +114,27 @@ class LiTS(base.ImageSegmentation):
         sample_index, slice_index = self._indices[index]
         volume_path = self._volume_files[sample_index]
         image_array = io.read_nifti(volume_path, slice_index)
+        if self._fix_orientation:
+            image_array = self._orientation(image_array, sample_index)
         return tv_tensors.Image(image_array.transpose(2, 0, 1))
 
     @override
     def load_mask(self, index: int) -> tv_tensors.Mask:
         sample_index, slice_index = self._indices[index]
-        segmentation_path = self.
+        segmentation_path = self._segmentation_file(sample_index)
         semantic_labels = io.read_nifti(segmentation_path, slice_index)
+        if self._fix_orientation:
+            semantic_labels = self._orientation(semantic_labels, sample_index)
         return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
 
+    def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray:
+        volume_path = self._volume_files[sample_index]
+        orientation = io.fetch_nifti_axis_direction_code(volume_path)
+        array = np.rot90(array, axes=(0, 1))
+        if orientation == "LPS":
+            array = np.flip(array, axis=0)
+        return array.copy()
+
     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
         _, slice_index = self._indices[index]
@@ -137,11 +156,10 @@ class LiTS(base.ImageSegmentation):
         files = glob.glob(files_pattern, recursive=True)
         return utils.numeric_sort(files)
 
-
-
-
-
-        return utils.numeric_sort(files)
+    def _segmentation_file(self, index: int) -> str:
+        volume_file_path = self._volume_files[index]
+        segmentation_file = os.path.basename(volume_file_path).replace("volume", "segmentation")
+        return os.path.join(os.path.dirname(volume_file_path), segmentation_file)
 
     def _create_indices(self) -> List[Tuple[int, int]]:
         """Builds the dataset indices for the specified split.
@@ -161,17 +179,19 @@ class LiTS(base.ImageSegmentation):
 
     def _get_split_indices(self) -> List[int]:
         """Returns the sample indices for the specified dataset split."""
-
-
-
-
-
+        indices = list(range(len(self._volume_files)))
+        train_indices, val_indices, test_indices = splitting.random_split(
+            indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed
+        )
+        split_indices_dict = {
+            "train": train_indices,
+            "val": val_indices,
+            "test": test_indices,
+            None: indices,
         }
-
-        if index_ranges is None:
+        if self._split not in split_indices_dict:
            raise ValueError("Invalid data split. Use 'train', 'val', 'test' or `None`.")
-
-        return data_utils.ranges_to_indices(index_ranges)
+        return list(split_indices_dict[self._split])
 
     def _print_license(self) -> None:
         """Prints the dataset license."""
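In `_get_split_indices`, fixed index ranges give way to a seeded random split via the new `eva.core.data.splitting` module. A minimal sketch of the call shape, as read off the hunk above (the volume count here is hypothetical; the ratios and seed mirror the class defaults shown in the diff):

```python
from eva.core.data import splitting

indices = list(range(131))  # hypothetical number of LiTS volumes
train_idx, val_idx, test_idx = splitting.random_split(
    indices, 0.7, 0.15, 0.15, seed=8  # defaults visible in the diff
)
print(len(train_idx), len(val_idx), len(test_idx))
```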
eva/vision/data/datasets/segmentation/lits_balanced.py
ADDED

@@ -0,0 +1,93 @@
+"""Balanced LiTS dataset."""
+
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from typing_extensions import override
+
+from eva.vision.data.datasets.segmentation import lits
+from eva.vision.utils import io
+
+
+class LiTSBalanced(lits.LiTS):
+    """Balanced version of the LiTS - Liver Tumor Segmentation Challenge dataset.
+
+    For each volume in the dataset, we sample the same number of slices where
+    only the liver and where both liver and tumor are present.
+
+    Webpage: https://competitions.codalab.org/competitions/17094
+
+    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
+    """
+
+    _expected_dataset_lengths: Dict[str | None, int] = {
+        "train": 5514,
+        "val": 1332,
+        "test": 1530,
+        None: 8376,
+    }
+    """Dataset version and split to the expected size."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"] | None = None,
+        transforms: Callable | None = None,
+        seed: int = 8,
+    ) -> None:
+        """Initialize dataset.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use.
+            transforms: A function/transforms that takes in an image and a target
+                mask and returns the transformed versions of both.
+            seed: Seed used for generating the dataset splits and sampling of the slices.
+        """
+        super().__init__(root=root, split=split, transforms=transforms, seed=seed)
+
+    @override
+    def _create_indices(self) -> List[Tuple[int, int]]:
+        """Builds the dataset indices for the specified split.
+
+        Returns:
+            A list of tuples, where the first value indicates the
+            sample index which the second its corresponding slice
+            index.
+        """
+        split_indices = set(self._get_split_indices())
+        indices: List[Tuple[int, int]] = []
+        random_generator = np.random.default_rng(seed=self._seed)
+
+        for sample_idx in range(len(self._volume_files)):
+            if sample_idx not in split_indices:
+                continue
+
+            segmentation = io.read_nifti(self._segmentation_file(sample_idx))
+            tumor_filter = segmentation == 2
+            tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
+
+            if tumor_filter.sum() == 0:
+                continue
+
+            liver_filter = segmentation == 1
+            liver_slice_filter = liver_filter.sum(axis=(0, 1)) > 0
+
+            liver_and_tumor_filter = liver_slice_filter & tumor_slice_filter
+            liver_only_filter = liver_slice_filter & ~tumor_slice_filter
+
+            n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum())
+            tumor_indices = list(np.where(liver_and_tumor_filter)[0])
+            tumor_indices = list(
+                random_generator.choice(tumor_indices, size=n_slice_samples, replace=False)
+            )
+
+            liver_indices = list(np.where(liver_only_filter)[0])
+            liver_indices = list(
+                random_generator.choice(liver_indices, size=n_slice_samples, replace=False)
+            )
+
+            indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])
+
+        return list(indices)
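The balancing in `_create_indices` keeps, per volume, equal numbers of slices that contain liver plus tumor and slices that contain liver only. The same selection logic on a toy mask, in plain NumPy (label values follow the diff: 1 = liver, 2 = tumor, slices along the last axis; the shapes are made up):

```python
import numpy as np

rng = np.random.default_rng(seed=8)

segmentation = np.zeros((16, 16, 40), dtype=np.uint8)
segmentation[4:12, 4:12, 5:30] = 1  # liver visible in slices 5..29
segmentation[6:8, 6:8, 20:25] = 2   # tumor visible in slices 20..24

liver_slices = (segmentation == 1).sum(axis=(0, 1)) > 0
tumor_slices = (segmentation == 2).sum(axis=(0, 1)) > 0

liver_and_tumor = liver_slices & tumor_slices
liver_only = liver_slices & ~tumor_slices

# Sample the same number of slices from both groups, as the dataset class does.
n = min(liver_and_tumor.sum(), liver_only.sum())
tumor_idx = rng.choice(np.where(liver_and_tumor)[0], size=n, replace=False)
liver_idx = rng.choice(np.where(liver_only)[0], size=n, replace=False)
print(sorted(tumor_idx), sorted(liver_idx))  # n slice indices from each group
```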
eva/vision/data/datasets/segmentation/monusac.py
CHANGED

@@ -10,12 +10,12 @@ import imagesize
 import numpy as np
 import numpy.typing as npt
 import torch
-import tqdm
 from skimage import draw
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
+from eva.core.utils.progress_bar import tqdm
 from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -84,7 +84,7 @@ class MoNuSAC(base.ImageSegmentation):
     @property
     @override
     def classes(self) -> List[str]:
-        return ["Epithelial", "Lymphocyte", "Neutrophil", "Macrophage"]
+        return ["Background", "Epithelial", "Lymphocyte", "Neutrophil", "Macrophage", "Ambiguous"]
 
     @functools.cached_property
     @override
@@ -107,8 +107,8 @@ class MoNuSAC(base.ImageSegmentation):
         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths.get(self._split, 0),
-            n_classes=
-            first_and_last_labels=("
+            n_classes=6,
+            first_and_last_labels=("Background", "Ambiguous"),
         )
 
     @override
@@ -161,7 +161,7 @@ class MoNuSAC(base.ImageSegmentation):
             for index, filename in enumerate(self._image_files)
         ]
         to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
-        for sample_index, filename in tqdm
+        for sample_index, filename in tqdm(
            list(to_export),
            desc=">> Exporting semantic masks",
            leave=False,
@@ -199,9 +199,9 @@ class MoNuSAC(base.ImageSegmentation):
         semantic_labels = np.zeros((height, width), "uint8")  # type: ignore[reportCallIssue]
         for level in range(len(root)):
             label = [item.attrib["Name"] for item in root[level][0]][0]
-            class_id = self.class_to_idx.get(label,
+            class_id = self.class_to_idx.get(label, self.class_to_idx["Ambiguous"])
             # for the test dataset an additional class 'Ambiguous' was added for
-            # difficult regions with fuzzy boundaries
+            # difficult regions with fuzzy boundaries
             regions = [item for child in root[level] for item in child if item.tag == "Region"]
             for region in regions:
                 vertices = np.array(
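The `class_to_idx.get(label, ...)` change maps any annotation outside the named classes onto the "Ambiguous" index, which MoNuSAC's test split uses for regions with fuzzy boundaries. The fallback in isolation (the dict comprehension is an assumption consistent with the `class_to_idx` property shown elsewhere in this diff):

```python
classes = ["Background", "Epithelial", "Lymphocyte", "Neutrophil", "Macrophage", "Ambiguous"]
class_to_idx = {label: index for index, label in enumerate(classes)}

# Unknown labels fall back to the "Ambiguous" index (5).
for label in ("Epithelial", "Ambiguous", "Unlabeled"):
    print(label, class_to_idx.get(label, class_to_idx["Ambiguous"]))
```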
eva/vision/data/datasets/segmentation/total_segmentator_2d.py
CHANGED

@@ -3,16 +3,18 @@
 import functools
 import os
 from glob import glob
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Tuple
 
 import numpy as np
 import numpy.typing as npt
 import torch
-import tqdm
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
+from eva.core.utils import io as core_io
+from eva.core.utils import multiprocessing
 from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -65,6 +67,8 @@ class TotalSegmentator2D(base.ImageSegmentation):
         download: bool = False,
         classes: List[str] | None = None,
         optimize_mask_loading: bool = True,
+        decompress: bool = True,
+        num_workers: int = 10,
         transforms: Callable | None = None,
     ) -> None:
         """Initialize dataset.
@@ -85,8 +89,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
                 in order to optimize the loading time. In the `setup` method, it
                 will reformat the binary one-hot masks to a semantic mask and store
                 it on disk.
+            decompress: Whether to decompress the ct.nii.gz files when preparing the data.
+                The label masks won't be decompressed, but when enabling optimize_mask_loading
+                it will export the semantic label masks to a single file in uncompressed .nii
+                format.
+            num_workers: The number of workers to use for optimizing the masks &
+                decompressing the .gz files.
             transforms: A function/transforms that takes in an image and a target
                 mask and returns the transformed versions of both.
+
         """
         super().__init__(transforms=transforms)
 
@@ -96,6 +107,8 @@ class TotalSegmentator2D(base.ImageSegmentation):
         self._download = download
         self._classes = classes
         self._optimize_mask_loading = optimize_mask_loading
+        self._decompress = decompress
+        self._num_workers = num_workers
 
         if self._optimize_mask_loading and self._classes is not None:
             raise ValueError(
@@ -128,23 +141,29 @@ class TotalSegmentator2D(base.ImageSegmentation):
     def class_to_idx(self) -> Dict[str, int]:
         return {label: index for index, label in enumerate(self.classes)}
 
+    @property
+    def _file_suffix(self) -> str:
+        return "nii" if self._decompress else "nii.gz"
+
     @override
-    def filename(self, index: int
+    def filename(self, index: int) -> str:
         sample_idx, _ = self._indices[index]
         sample_dir = self._samples_dirs[sample_idx]
-        return os.path.join(sample_dir, "ct.
+        return os.path.join(sample_dir, f"ct.{self._file_suffix}")
 
     @override
     def prepare_data(self) -> None:
         if self._download:
             self._download_dataset()
+        if self._decompress:
+            self._decompress_files()
+        self._samples_dirs = self._fetch_samples_dirs()
+        if self._optimize_mask_loading:
+            self._export_semantic_label_masks()
 
     @override
     def configure(self) -> None:
-        self._samples_dirs = self._fetch_samples_dirs()
         self._indices = self._create_indices()
-        if self._optimize_mask_loading:
-            self._export_semantic_label_masks()
 
     @override
     def validate(self) -> None:
@@ -186,16 +205,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
         return {"slice_index": slice_index}
 
     def _load_mask(self, index: int) -> tv_tensors.Mask:
-        """Loads and builds the segmentation mask from NifTi files."""
         sample_index, slice_index = self._indices[index]
         semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
-        return tv_tensors.Mask(semantic_labels, dtype=torch.int64)  # type: ignore[reportCallIssue]
+        return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
 
     def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
         """Loads the segmentation mask from a semantic label NifTi file."""
         sample_index, slice_index = self._indices[index]
         masks_dir = self._get_masks_dir(sample_index)
-        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii
+        filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
         semantic_labels = io.read_nifti(filename, slice_index)
         return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
 
@@ -209,7 +227,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
             slice_index: Whether to return only a specific slice.
         """
         masks_dir = self._get_masks_dir(sample_index)
-        mask_paths = [os.path.join(masks_dir, label
+        mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes]
         binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
         background_mask = np.zeros_like(binary_masks[0])
         return np.argmax([background_mask] + binary_masks, axis=0)
@@ -219,24 +237,28 @@ class TotalSegmentator2D(base.ImageSegmentation):
         total_samples = len(self._samples_dirs)
         masks_dirs = map(self._get_masks_dir, range(total_samples))
         semantic_labels = [
-            (index, os.path.join(directory, "semantic_labels", "masks.nii
+            (index, os.path.join(directory, "semantic_labels", "masks.nii"))
             for index, directory in enumerate(masks_dirs)
         ]
         to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
 
-
-            list(to_export),
-            desc=">> Exporting optimized semantic masks",
-            leave=False,
-        ):
+        def _process_mask(sample_index: Any, filename: str) -> None:
            semantic_labels = self._load_masks_as_semantic_label(sample_index)
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            io.save_array_as_nifti(semantic_labels, filename)
 
+        multiprocessing.run_with_threads(
+            _process_mask,
+            list(to_export),
+            num_workers=self._num_workers,
+            progress_desc=">> Exporting optimized semantic mask",
+            return_results=False,
+        )
+
     def _get_image_path(self, sample_index: int) -> str:
         """Returns the corresponding image path."""
         sample_dir = self._samples_dirs[sample_index]
-        return os.path.join(self._root, sample_dir, "ct.
+        return os.path.join(self._root, sample_dir, f"ct.{self._file_suffix}")
 
     def _get_masks_dir(self, sample_index: int) -> str:
         """Returns the directory of the corresponding masks."""
@@ -246,7 +268,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
     def _get_semantic_labels_filename(self, sample_index: int) -> str:
         """Returns the semantic label filename."""
         masks_dir = self._get_masks_dir(sample_index)
-        return os.path.join(masks_dir, "semantic_labels", "masks.nii
+        return os.path.join(masks_dir, "semantic_labels", "masks.nii")
 
     def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
         """Returns the total amount of slices of a sample."""
@@ -320,6 +342,16 @@ class TotalSegmentator2D(base.ImageSegmentation):
             remove_finished=True,
         )
 
+    def _decompress_files(self) -> None:
+        compressed_paths = Path(self._root).rglob("*/ct.nii.gz")
+        multiprocessing.run_with_threads(
+            core_io.gunzip_file,
+            [(str(path),) for path in compressed_paths],
+            num_workers=self._num_workers,
+            progress_desc=">> Decompressing .gz files",
+            return_results=False,
+        )
+
     def _print_license(self) -> None:
         """Prints the dataset license."""
         print(f"Dataset license: {self._license}")
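Both new code paths (`_export_semantic_label_masks` and `_decompress_files`) fan work out through `eva.core.utils.multiprocessing.run_with_threads`, whose call shape can be read off the hunks above: a callable, a list of argument tuples, plus `num_workers`, `progress_desc`, and `return_results`. A minimal sketch under those assumptions (the worker and file names here are placeholders):

```python
from eva.core.utils import multiprocessing

def process(path: str) -> None:
    """Placeholder for per-file work such as core_io.gunzip_file(path)."""
    print(f"processing {path}")

paths = ["a.nii.gz", "b.nii.gz", "c.nii.gz"]
multiprocessing.run_with_threads(
    process,
    [(path,) for path in paths],  # one argument tuple per task
    num_workers=4,
    progress_desc=">> Demo",
    return_results=False,  # side effects only, matching the dataset usage
)
```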
eva/vision/data/datasets/wsi.py
CHANGED
@@ -2,8 +2,9 @@
 
 import bisect
 import os
-from typing import Callable, List
+from typing import Any, Callable, Dict, List
 
+import pandas as pd
 from loguru import logger
 from torch.utils.data import dataset as torch_datasets
 from torchvision import tv_tensors
@@ -85,6 +86,17 @@ class WsiDataset(vision.VisionDataset):
         patch = self._apply_transforms(patch)
         return patch
 
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        """Loads the metadata for the patch at the specified index."""
+        x, y = self._coords.x_y[index]
+        return {
+            "x": x,
+            "y": y,
+            "width": self._coords.width,
+            "height": self._coords.height,
+            "level_idx": self._coords.level_idx,
+        }
+
     def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
         if self._image_transforms is not None:
             image = self._image_transforms(image)
@@ -105,6 +117,7 @@ class MultiWsiDataset(vision.VisionDataset):
         overwrite_mpp: float | None = None,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
     ):
         """Initializes a new dataset instance.
 
@@ -118,6 +131,7 @@ class MultiWsiDataset(vision.VisionDataset):
             sampler: The sampler to use for sampling patch coordinates.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
         """
         super().__init__()
 
@@ -130,6 +144,7 @@ class MultiWsiDataset(vision.VisionDataset):
         self._sampler = sampler
         self._backend = backend
         self._image_transforms = image_transforms
+        self._coords_path = coords_path
 
         self._concat_dataset: torch_datasets.ConcatDataset
 
@@ -146,6 +161,7 @@ class MultiWsiDataset(vision.VisionDataset):
     @override
     def configure(self) -> None:
         self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
+        self._save_coords_to_file()
 
     @override
     def __len__(self) -> int:
@@ -159,6 +175,12 @@ class MultiWsiDataset(vision.VisionDataset):
     def filename(self, index: int) -> str:
         return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
 
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        """Loads the metadata for the patch at the specified index."""
+        dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
+        patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
+        return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata
+
     def _load_datasets(self) -> list[WsiDataset]:
         logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
         wsi_datasets = []
@@ -185,3 +207,17 @@ class MultiWsiDataset(vision.VisionDataset):
 
     def _get_dataset_idx(self, index: int) -> int:
         return bisect.bisect_right(self.cumulative_sizes, index)
+
+    def _get_sample_idx(self, index: int) -> int:
+        dataset_idx = self._get_dataset_idx(index)
+        return index if dataset_idx == 0 else index - self.cumulative_sizes[dataset_idx - 1]
+
+    def _save_coords_to_file(self):
+        if self._coords_path is not None:
+            coords = [
+                {"file": self._file_paths[i]} | dataset._coords.to_dict()
+                for i, dataset in enumerate(self.datasets)
+            ]
+            os.makedirs(os.path.abspath(os.path.join(self._coords_path, os.pardir)), exist_ok=True)
+            pd.DataFrame(coords).to_csv(self._coords_path, index=False)
+            logger.info(f"Saved patch coordinates to: {self._coords_path}")
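The new `_get_sample_idx` converts a global patch index into a per-slide index via `ConcatDataset`-style cumulative sizes, pairing with the existing `_get_dataset_idx`. The arithmetic in isolation (the patch counts are made up):

```python
import bisect

cumulative_sizes = [100, 250, 400]  # e.g. three WSIs with 100, 150, 150 patches

def locate(index: int) -> tuple[int, int]:
    dataset_idx = bisect.bisect_right(cumulative_sizes, index)
    sample_idx = index if dataset_idx == 0 else index - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx

print(locate(0))    # (0, 0)   -> first patch of the first slide
print(locate(100))  # (1, 0)   -> first patch of the second slide
print(locate(399))  # (2, 149) -> last patch of the third slide
```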
eva/vision/data/wsi/patching/coordinates.py
CHANGED

@@ -2,7 +2,7 @@
 
 import dataclasses
 import functools
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from eva.vision.data.wsi import backends
 from eva.vision.data.wsi.patching import samplers
@@ -75,6 +75,14 @@ class PatchCoordinates:
 
         return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))
 
+    def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
+        """Convert the coordinates to a dictionary."""
+        include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
+        coord_dict = dataclasses.asdict(self)
+        if include_keys:
+            coord_dict = {key: coord_dict[key] for key in include_keys}
+        return coord_dict
+
 
 @functools.lru_cache(LRU_CACHE_SIZE)
 def get_cached_coords(
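`PatchCoordinates.to_dict` is essentially `dataclasses.asdict` filtered to a default key subset, so fields like the mask are dropped when `MultiWsiDataset` writes coordinates to CSV. The same pattern on a stand-in dataclass:

```python
import dataclasses
from typing import Any, Dict, List, Tuple

@dataclasses.dataclass
class Coords:  # stand-in for PatchCoordinates
    x_y: List[Tuple[int, int]]
    width: int
    height: int
    level_idx: int
    mask: Any = None

def to_dict(obj: Coords, include_keys: List[str] | None = None) -> Dict[str, Any]:
    include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
    coord_dict = dataclasses.asdict(obj)
    return {key: coord_dict[key] for key in include_keys}

print(to_dict(Coords(x_y=[(0, 0), (224, 0)], width=224, height=224, level_idx=0)))
```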
eva/vision/data/wsi/patching/samplers/_utils.py
CHANGED

@@ -1,14 +1,8 @@
-import random
 from typing import Tuple
 
 import numpy as np
 
 
-def set_seed(seed: int) -> None:
-    random.seed(seed)
-    np.random.seed(seed)
-
-
 def get_grid_coords_and_indices(
     layer_shape: Tuple[int, int],
     width: int,
@@ -33,8 +27,8 @@ def get_grid_coords_and_indices(
 
     indices = list(range(len(x_y)))
     if shuffle:
-
-
+        random_generator = np.random.default_rng(seed)
+        random_generator.shuffle(indices)
     return x_y, indices
 
 
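The removal of `set_seed` swaps global reseeding (`random.seed` / `np.random.seed`) for a local NumPy `Generator`, so shuffling is reproducible without mutating process-wide RNG state. The difference in miniature:

```python
import numpy as np

indices = list(range(10))

# Old approach: seeding the legacy global RNG affects every later np.random call.
np.random.seed(42)
np.random.shuffle(indices)

# New approach: an isolated generator, reproducible and side-effect free.
rng = np.random.default_rng(42)
rng.shuffle(indices)
print(indices)
```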
eva/vision/data/wsi/patching/samplers/random.py
CHANGED

@@ -18,6 +18,7 @@ class RandomSampler(base.Sampler):
         """Initializes the sampler."""
         self.seed = seed
         self.n_samples = n_samples
+        self.random_generator = random.Random(seed)  # nosec
 
     def sample(
         self,
@@ -33,9 +34,10 @@ class RandomSampler(base.Sampler):
             layer_shape: The shape of the layer.
         """
         _utils.validate_dimensions(width, height, layer_shape)
-        _utils.set_seed(self.seed)
 
         x_max, y_max = layer_shape[0], layer_shape[1]
         for _ in range(self.n_samples):
-            x, y =
+            x, y = self.random_generator.randint(0, x_max - width), self.random_generator.randint(
+                0, y_max - height
+            )
             yield x, y
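`RandomSampler` now owns a `random.Random(seed)` instance instead of reseeding the global `random` module on each call, so successive `sample()` calls advance one reproducible stream. A simplified stand-in showing the pattern (not the eva class itself):

```python
import random
from typing import Iterator, Tuple

class StatefulSampler:  # simplified stand-in for RandomSampler
    def __init__(self, n_samples: int, seed: int = 42) -> None:
        self.n_samples = n_samples
        self.random_generator = random.Random(seed)  # nosec

    def sample(
        self, width: int, height: int, layer_shape: Tuple[int, int]
    ) -> Iterator[Tuple[int, int]]:
        # Yield top-left corners of patches that fit inside the layer.
        x_max, y_max = layer_shape
        for _ in range(self.n_samples):
            yield (
                self.random_generator.randint(0, x_max - width),
                self.random_generator.randint(0, y_max - height),
            )

sampler = StatefulSampler(n_samples=3, seed=8)
print(list(sampler.sample(224, 224, (10_000, 10_000))))
```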
eva/vision/losses/__init__.py
CHANGED