kaiko-eva 0.1.0__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/core/callbacks/writers/embeddings/base.py +3 -4
- eva/core/data/dataloaders/dataloader.py +2 -2
- eva/core/data/splitting/random.py +6 -5
- eva/core/data/splitting/stratified.py +12 -6
- eva/core/losses/__init__.py +5 -0
- eva/core/losses/cross_entropy.py +27 -0
- eva/core/metrics/__init__.py +0 -4
- eva/core/metrics/defaults/__init__.py +0 -2
- eva/core/models/modules/module.py +9 -9
- eva/core/models/transforms/extract_cls_features.py +17 -9
- eva/core/models/transforms/extract_patch_features.py +23 -11
- eva/core/utils/progress_bar.py +15 -0
- eva/vision/data/datasets/__init__.py +4 -0
- eva/vision/data/datasets/classification/__init__.py +2 -1
- eva/vision/data/datasets/classification/camelyon16.py +4 -1
- eva/vision/data/datasets/classification/panda.py +17 -1
- eva/vision/data/datasets/classification/wsi.py +4 -1
- eva/vision/data/datasets/segmentation/__init__.py +2 -0
- eva/vision/data/datasets/segmentation/consep.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +49 -29
- eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
- eva/vision/data/datasets/segmentation/monusac.py +7 -7
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +2 -2
- eva/vision/data/datasets/wsi.py +37 -1
- eva/vision/data/wsi/patching/coordinates.py +9 -1
- eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
- eva/vision/data/wsi/patching/samplers/random.py +4 -2
- eva/vision/losses/__init__.py +2 -2
- eva/vision/losses/dice.py +75 -8
- eva/vision/metrics/__init__.py +11 -0
- eva/vision/metrics/defaults/__init__.py +7 -0
- eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
- eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
- eva/vision/metrics/segmentation/BUILD +1 -0
- eva/vision/metrics/segmentation/__init__.py +9 -0
- eva/vision/metrics/segmentation/_utils.py +69 -0
- eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
- eva/vision/metrics/segmentation/mean_iou.py +57 -0
- eva/vision/models/modules/semantic_segmentation.py +4 -3
- eva/vision/models/networks/backbones/_utils.py +12 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
- eva/vision/models/networks/backbones/pathology/histai.py +8 -2
- eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
- eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
- eva/vision/models/networks/backbones/pathology/paige.py +51 -0
- eva/vision/models/networks/decoders/__init__.py +1 -1
- eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
- eva/vision/models/networks/decoders/segmentation/base.py +16 -0
- eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
- eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
- eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
- eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
- eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
- eva/vision/utils/io/__init__.py +7 -1
- eva/vision/utils/io/nifti.py +19 -4
- {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/METADATA +3 -34
- {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/RECORD +61 -48
- {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/WHEEL +1 -1
- eva/core/metrics/mean_iou.py +0 -120
- eva/vision/models/networks/decoders/decoder.py +0 -7
- {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/lits.py CHANGED

@@ -5,12 +5,14 @@ import glob
 import os
 from typing import Any, Callable, Dict, List, Literal, Tuple
 
+import numpy as np
+import numpy.typing as npt
 import torch
 from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.core import utils
-from eva.
+from eva.core.data import splitting
 from eva.vision.data.datasets import _validators
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -20,22 +22,23 @@ class LiTS(base.ImageSegmentation):
     """LiTS - Liver Tumor Segmentation Challenge.
 
     Webpage: https://competitions.codalab.org/competitions/17094
-
-    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
     """
 
-
-
-
+    _train_ratio: float = 0.7
+    _val_ratio: float = 0.15
+    _test_ratio: float = 0.15
     """Index ranges per split."""
 
+    _fix_orientation: bool = True
+    """Whether to fix the orientation of the images to match the default for radiologists."""
+
     _sample_every_n_slices: int | None = None
     """The amount of slices to sub-sample per 3D CT scan image."""
 
     _expected_dataset_lengths: Dict[str | None, int] = {
-        "train":
-        "val":
-        "test":
+        "train": 38686,
+        "val": 11192,
+        "test": 8760,
         None: 58638,
     }
     """Dataset version and split to the expected size."""
@@ -51,6 +54,7 @@ class LiTS(base.ImageSegmentation):
         root: str,
         split: Literal["train", "val", "test"] | None = None,
         transforms: Callable | None = None,
+        seed: int = 8,
     ) -> None:
         """Initialize dataset.
 
@@ -60,12 +64,13 @@ class LiTS(base.ImageSegmentation):
             split: Dataset split to use.
             transforms: A function/transforms that takes in an image and a target
                 mask and returns the transformed versions of both.
+            seed: Seed used for generating the dataset splits.
         """
         super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
-
+        self._seed = seed
         self._indices: List[Tuple[int, int]] = []
 
     @property
@@ -90,10 +95,12 @@ class LiTS(base.ImageSegmentation):
 
     @override
     def validate(self) -> None:
-
-
-
-
+        for i in range(len(self._volume_files)):
+            seg_path = self._segmentation_file(i)
+            if not os.path.exists(seg_path):
+                raise FileNotFoundError(
+                    f"Segmentation file {seg_path} not found for volume {self._volume_files[i]}."
+                )
 
         _validators.check_dataset_integrity(
             self,
@@ -107,15 +114,27 @@ class LiTS(base.ImageSegmentation):
         sample_index, slice_index = self._indices[index]
         volume_path = self._volume_files[sample_index]
         image_array = io.read_nifti(volume_path, slice_index)
+        if self._fix_orientation:
+            image_array = self._orientation(image_array, sample_index)
         return tv_tensors.Image(image_array.transpose(2, 0, 1))
 
     @override
     def load_mask(self, index: int) -> tv_tensors.Mask:
         sample_index, slice_index = self._indices[index]
-        segmentation_path = self.
+        segmentation_path = self._segmentation_file(sample_index)
         semantic_labels = io.read_nifti(segmentation_path, slice_index)
+        if self._fix_orientation:
+            semantic_labels = self._orientation(semantic_labels, sample_index)
         return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
 
+    def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray:
+        volume_path = self._volume_files[sample_index]
+        orientation = io.fetch_nifti_axis_direction_code(volume_path)
+        array = np.rot90(array, axes=(0, 1))
+        if orientation == "LPS":
+            array = np.flip(array, axis=0)
+        return array.copy()
+
     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
         _, slice_index = self._indices[index]
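The new `_orientation` helper is easy to sanity-check outside of eva: `np.rot90(..., axes=(0, 1))` rotates each slice 90 degrees counter-clockwise in-plane, and volumes whose axis-direction code is "LPS" get an extra flip. A minimal NumPy sketch (the array values are arbitrary):

import numpy as np

slice_2d = np.array([[1, 2, 3],
                     [4, 5, 6]])

rotated = np.rot90(slice_2d, axes=(0, 1))  # counter-clockwise in-plane rotation
print(rotated)   # [[3 6], [2 5], [1 4]]

flipped = np.flip(rotated, axis=0)         # applied only for "LPS" volumes
print(flipped)   # [[1 4], [2 5], [3 6]]

The trailing `array.copy()` in the diff is not cosmetic: `rot90` and `flip` return negatively-strided views, and the copy makes the array contiguous before it is wrapped into `tv_tensors`.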
@@ -137,11 +156,10 @@ class LiTS(base.ImageSegmentation):
         files = glob.glob(files_pattern, recursive=True)
         return utils.numeric_sort(files)
 
-
-
-
-
-        return utils.numeric_sort(files)
+    def _segmentation_file(self, index: int) -> str:
+        volume_file_path = self._volume_files[index]
+        segmentation_file = os.path.basename(volume_file_path).replace("volume", "segmentation")
+        return os.path.join(os.path.dirname(volume_file_path), segmentation_file)
 
     def _create_indices(self) -> List[Tuple[int, int]]:
         """Builds the dataset indices for the specified split.
@@ -161,17 +179,19 @@ class LiTS(base.ImageSegmentation):
 
     def _get_split_indices(self) -> List[int]:
         """Returns the sample indices for the specified dataset split."""
-
-
-
-
-
+        indices = list(range(len(self._volume_files)))
+        train_indices, val_indices, test_indices = splitting.random_split(
+            indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed
+        )
+        split_indices_dict = {
+            "train": train_indices,
+            "val": val_indices,
+            "test": test_indices,
+            None: indices,
         }
-
-        if index_ranges is None:
+        if self._split not in split_indices_dict:
             raise ValueError("Invalid data split. Use 'train', 'val', 'test' or `None`.")
-
-        return data_utils.ranges_to_indices(index_ranges)
+        return list(split_indices_dict[self._split])
 
     def _print_license(self) -> None:
         """Prints the dataset license."""
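`_get_split_indices` now delegates to `eva.core.data.splitting.random_split`, itself touched in this release (see the file list). Going only by the call in the hunk above, usage looks roughly like the sketch below; the exact return types are an assumption:

from eva.core.data import splitting

indices = list(range(131))  # e.g. one index per CT volume
train_idx, val_idx, test_idx = splitting.random_split(
    indices, 0.7, 0.15, 0.15, seed=8
)
# The same seed yields the same deterministic 70/15/15 split on every run.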
eva/vision/data/datasets/segmentation/lits_balanced.py ADDED

@@ -0,0 +1,93 @@
+"""Balanced LiTS dataset."""
+
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from typing_extensions import override
+
+from eva.vision.data.datasets.segmentation import lits
+from eva.vision.utils import io
+
+
+class LiTSBalanced(lits.LiTS):
+    """Balanced version of the LiTS - Liver Tumor Segmentation Challenge dataset.
+
+    For each volume in the dataset, we sample the same number of slices where
+    only the liver and where both liver and tumor are present.
+
+    Webpage: https://competitions.codalab.org/competitions/17094
+
+    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
+    """
+
+    _expected_dataset_lengths: Dict[str | None, int] = {
+        "train": 5514,
+        "val": 1332,
+        "test": 1530,
+        None: 8376,
+    }
+    """Dataset version and split to the expected size."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"] | None = None,
+        transforms: Callable | None = None,
+        seed: int = 8,
+    ) -> None:
+        """Initialize dataset.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use.
+            transforms: A function/transforms that takes in an image and a target
+                mask and returns the transformed versions of both.
+            seed: Seed used for generating the dataset splits and sampling of the slices.
+        """
+        super().__init__(root=root, split=split, transforms=transforms, seed=seed)
+
+    @override
+    def _create_indices(self) -> List[Tuple[int, int]]:
+        """Builds the dataset indices for the specified split.
+
+        Returns:
+            A list of tuples, where the first value indicates the
+            sample index which the second its corresponding slice
+            index.
+        """
+        split_indices = set(self._get_split_indices())
+        indices: List[Tuple[int, int]] = []
+        random_generator = np.random.default_rng(seed=self._seed)
+
+        for sample_idx in range(len(self._volume_files)):
+            if sample_idx not in split_indices:
+                continue
+
+            segmentation = io.read_nifti(self._segmentation_file(sample_idx))
+            tumor_filter = segmentation == 2
+            tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
+
+            if tumor_filter.sum() == 0:
+                continue
+
+            liver_filter = segmentation == 1
+            liver_slice_filter = liver_filter.sum(axis=(0, 1)) > 0
+
+            liver_and_tumor_filter = liver_slice_filter & tumor_slice_filter
+            liver_only_filter = liver_slice_filter & ~tumor_slice_filter
+
+            n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum())
+            tumor_indices = list(np.where(liver_and_tumor_filter)[0])
+            tumor_indices = list(
+                random_generator.choice(tumor_indices, size=n_slice_samples, replace=False)
+            )
+
+            liver_indices = list(np.where(liver_only_filter)[0])
+            liver_indices = list(
+                random_generator.choice(liver_indices, size=n_slice_samples, replace=False)
+            )
+
+            indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])
+
+        return list(indices)
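The balancing in `_create_indices` boils down to per-volume subsampling: find slices containing both liver and tumor (label 2) and slices containing liver only (label 1), then draw equally many of each. A standalone sketch of that selection logic on a synthetic `(H, W, slices)` mask:

import numpy as np

# Synthetic mask volume: 0 = background, 1 = liver, 2 = tumor.
segmentation = np.zeros((8, 8, 6), dtype=np.uint8)
segmentation[2:6, 2:6, 1:5] = 1  # liver present in slices 1-4
segmentation[3:5, 3:5, 3:5] = 2  # tumor present in slices 3-4

tumor_slices = (segmentation == 2).sum(axis=(0, 1)) > 0
liver_slices = (segmentation == 1).sum(axis=(0, 1)) > 0

liver_and_tumor = liver_slices & tumor_slices  # slices 3, 4
liver_only = liver_slices & ~tumor_slices      # slices 1, 2

rng = np.random.default_rng(seed=8)
n = min(liver_and_tumor.sum(), liver_only.sum())  # 2
picked = np.concatenate([
    rng.choice(np.where(liver_and_tumor)[0], size=n, replace=False),
    rng.choice(np.where(liver_only)[0], size=n, replace=False),
])
print(np.sort(picked))  # [1 2 3 4]: equally many slices from each group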
eva/vision/data/datasets/segmentation/monusac.py CHANGED

@@ -10,12 +10,12 @@ import imagesize
 import numpy as np
 import numpy.typing as npt
 import torch
-import tqdm
 from skimage import draw
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
+from eva.core.utils.progress_bar import tqdm
 from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
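Both segmentation datasets now pull `tqdm` from the new `eva/core/utils/progress_bar.py` (+15 lines in the file list; the same import swap appears in total_segmentator_2d.py below) instead of the `tqdm` package directly. That module's body is not part of this diff; a purely hypothetical minimal wrapper consistent with the import site could look like:

# Speculative sketch only -- the real eva/core/utils/progress_bar.py is not shown here.
import functools
import os

import tqdm as _tqdm

# Assumed intent: one place to set project-wide progress-bar defaults,
# e.g. disabling bars via an (invented) environment variable.
tqdm = functools.partial(
    _tqdm.tqdm,
    disable=os.environ.get("EVA_DISABLE_PROGRESS_BARS", "").lower() == "true",
)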
@@ -84,7 +84,7 @@ class MoNuSAC(base.ImageSegmentation):
     @property
     @override
     def classes(self) -> List[str]:
-        return ["Epithelial", "Lymphocyte", "Neutrophil", "Macrophage"]
+        return ["Background", "Epithelial", "Lymphocyte", "Neutrophil", "Macrophage", "Ambiguous"]
 
     @functools.cached_property
     @override
@@ -107,8 +107,8 @@ class MoNuSAC(base.ImageSegmentation):
         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths.get(self._split, 0),
-            n_classes=
-            first_and_last_labels=("
+            n_classes=6,
+            first_and_last_labels=("Background", "Ambiguous"),
         )
 
     @override
@@ -161,7 +161,7 @@ class MoNuSAC(base.ImageSegmentation):
             for index, filename in enumerate(self._image_files)
         ]
         to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
-        for sample_index, filename in tqdm
+        for sample_index, filename in tqdm(
             list(to_export),
             desc=">> Exporting semantic masks",
             leave=False,
@@ -199,9 +199,9 @@ class MoNuSAC(base.ImageSegmentation):
         semantic_labels = np.zeros((height, width), "uint8")  # type: ignore[reportCallIssue]
         for level in range(len(root)):
             label = [item.attrib["Name"] for item in root[level][0]][0]
-            class_id = self.class_to_idx.get(label,
+            class_id = self.class_to_idx.get(label, self.class_to_idx["Ambiguous"])
             # for the test dataset an additional class 'Ambiguous' was added for
-            # difficult regions with fuzzy boundaries
+            # difficult regions with fuzzy boundaries
             regions = [item for child in root[level] for item in child if item.tag == "Region"]
             for region in regions:
                 vertices = np.array(
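`class_to_idx` is a `functools.cached_property` on the dataset whose body is outside this diff. Assuming it simply enumerates `classes` in order, the new fallback maps any XML label without a known class onto `Ambiguous`:

classes = ["Background", "Epithelial", "Lymphocyte", "Neutrophil", "Macrophage", "Ambiguous"]
class_to_idx = {name: index for index, name in enumerate(classes)}  # assumed mapping

label = "some unrecognized test-set annotation"
class_id = class_to_idx.get(label, class_to_idx["Ambiguous"])
print(class_id)  # 5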
eva/vision/data/datasets/segmentation/total_segmentator_2d.py CHANGED

@@ -8,11 +8,11 @@ from typing import Any, Callable, Dict, List, Literal, Tuple
 import numpy as np
 import numpy.typing as npt
 import torch
-import tqdm
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
+from eva.core.utils.progress_bar import tqdm
 from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -224,7 +224,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
         ]
         to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
 
-        for sample_index, filename in tqdm
+        for sample_index, filename in tqdm(
             list(to_export),
             desc=">> Exporting optimized semantic masks",
             leave=False,
eva/vision/data/datasets/wsi.py CHANGED
@@ -2,8 +2,9 @@
 
 import bisect
 import os
-from typing import Callable, List
+from typing import Any, Callable, Dict, List
 
+import pandas as pd
 from loguru import logger
 from torch.utils.data import dataset as torch_datasets
 from torchvision import tv_tensors
@@ -85,6 +86,17 @@ class WsiDataset(vision.VisionDataset):
         patch = self._apply_transforms(patch)
         return patch
 
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        """Loads the metadata for the patch at the specified index."""
+        x, y = self._coords.x_y[index]
+        return {
+            "x": x,
+            "y": y,
+            "width": self._coords.width,
+            "height": self._coords.height,
+            "level_idx": self._coords.level_idx,
+        }
+
     def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
         if self._image_transforms is not None:
             image = self._image_transforms(image)
@@ -105,6 +117,7 @@ class MultiWsiDataset(vision.VisionDataset):
         overwrite_mpp: float | None = None,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
     ):
         """Initializes a new dataset instance.
 
@@ -118,6 +131,7 @@ class MultiWsiDataset(vision.VisionDataset):
             sampler: The sampler to use for sampling patch coordinates.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
         """
         super().__init__()
 
@@ -130,6 +144,7 @@ class MultiWsiDataset(vision.VisionDataset):
         self._sampler = sampler
         self._backend = backend
         self._image_transforms = image_transforms
+        self._coords_path = coords_path
 
         self._concat_dataset: torch_datasets.ConcatDataset
 
@@ -146,6 +161,7 @@ class MultiWsiDataset(vision.VisionDataset):
     @override
     def configure(self) -> None:
         self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
+        self._save_coords_to_file()
 
     @override
     def __len__(self) -> int:
@@ -159,6 +175,12 @@ class MultiWsiDataset(vision.VisionDataset):
     def filename(self, index: int) -> str:
         return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
 
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        """Loads the metadata for the patch at the specified index."""
+        dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
+        patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
+        return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata
+
     def _load_datasets(self) -> list[WsiDataset]:
         logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
         wsi_datasets = []
@@ -185,3 +207,17 @@ class MultiWsiDataset(vision.VisionDataset):
 
     def _get_dataset_idx(self, index: int) -> int:
         return bisect.bisect_right(self.cumulative_sizes, index)
+
+    def _get_sample_idx(self, index: int) -> int:
+        dataset_idx = self._get_dataset_idx(index)
+        return index if dataset_idx == 0 else index - self.cumulative_sizes[dataset_idx - 1]
+
+    def _save_coords_to_file(self):
+        if self._coords_path is not None:
+            coords = [
+                {"file": self._file_paths[i]} | dataset._coords.to_dict()
+                for i, dataset in enumerate(self.datasets)
+            ]
+            os.makedirs(os.path.abspath(os.path.join(self._coords_path, os.pardir)), exist_ok=True)
+            pd.DataFrame(coords).to_csv(self._coords_path, index=False)
+            logger.info(f"Saved patch coordinates to: {self._coords_path}")
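The flat-to-(slide, patch) index translation relies on `bisect` over `ConcatDataset.cumulative_sizes`; a worked standalone example of the two helpers:

import bisect

cumulative_sizes = [100, 250, 400]  # three WSIs with 100, 150 and 150 patches

def get_dataset_idx(index: int) -> int:
    return bisect.bisect_right(cumulative_sizes, index)

def get_sample_idx(index: int) -> int:
    dataset_idx = get_dataset_idx(index)
    return index if dataset_idx == 0 else index - cumulative_sizes[dataset_idx - 1]

print(get_dataset_idx(99), get_sample_idx(99))    # 0 99  (last patch of slide 0)
print(get_dataset_idx(100), get_sample_idx(100))  # 1 0   (first patch of slide 1)
print(get_dataset_idx(399), get_sample_idx(399))  # 2 149 (last patch of slide 2)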
eva/vision/data/wsi/patching/coordinates.py CHANGED

@@ -2,7 +2,7 @@
 
 import dataclasses
 import functools
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from eva.vision.data.wsi import backends
 from eva.vision.data.wsi.patching import samplers
@@ -75,6 +75,14 @@ class PatchCoordinates:
 
         return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))
 
+    def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
+        """Convert the coordinates to a dictionary."""
+        include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
+        coord_dict = dataclasses.asdict(self)
+        if include_keys:
+            coord_dict = {key: coord_dict[key] for key in include_keys}
+        return coord_dict
+
 
 @functools.lru_cache(LRU_CACHE_SIZE)
 def get_cached_coords(
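Combined with `_save_coords_to_file` in the wsi.py hunk above, `to_dict` turns each slide's coordinates into one CSV row. A trimmed stand-in illustrating the round trip (the real class holds more fields):

import dataclasses
from typing import Any, Dict, List, Tuple

@dataclasses.dataclass
class PatchCoordinates:  # reduced stand-in for the real dataclass
    x_y: List[Tuple[int, int]]
    width: int
    height: int
    level_idx: int

    def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
        include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
        coord_dict = dataclasses.asdict(self)
        return {key: coord_dict[key] for key in include_keys}

coords = PatchCoordinates(x_y=[(0, 0), (224, 0)], width=224, height=224, level_idx=0)
print({"file": "slide_1.svs"} | coords.to_dict())
# {'file': 'slide_1.svs', 'x_y': [(0, 0), (224, 0)], 'width': 224, 'height': 224, 'level_idx': 0}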
eva/vision/data/wsi/patching/samplers/_utils.py CHANGED

@@ -1,14 +1,8 @@
-import random
 from typing import Tuple
 
 import numpy as np
 
 
-def set_seed(seed: int) -> None:
-    random.seed(seed)
-    np.random.seed(seed)
-
-
 def get_grid_coords_and_indices(
     layer_shape: Tuple[int, int],
     width: int,
@@ -33,8 +27,8 @@ def get_grid_coords_and_indices(
 
     indices = list(range(len(x_y)))
     if shuffle:
-
-
+        random_generator = np.random.default_rng(seed)
+        random_generator.shuffle(indices)
     return x_y, indices
 
 
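The removed `set_seed` helper reseeded the global `random` and `numpy` state on every call; the replacement builds a local `np.random.default_rng(seed)` generator, which keeps the shuffle reproducible without process-wide side effects:

import numpy as np

indices = list(range(5))

rng = np.random.default_rng(seed=42)  # local, isolated generator
rng.shuffle(indices)
print(indices)  # identical on every run with seed=42

# Global draws elsewhere in the process are unaffected by the shuffle above.
print(np.random.rand())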
eva/vision/data/wsi/patching/samplers/random.py CHANGED

@@ -18,6 +18,7 @@ class RandomSampler(base.Sampler):
         """Initializes the sampler."""
        self.seed = seed
         self.n_samples = n_samples
+        self.random_generator = random.Random(seed)  # nosec
 
     def sample(
         self,
@@ -33,9 +34,10 @@ class RandomSampler(base.Sampler):
             layer_shape: The shape of the layer.
         """
         _utils.validate_dimensions(width, height, layer_shape)
-        _utils.set_seed(self.seed)
 
         x_max, y_max = layer_shape[0], layer_shape[1]
         for _ in range(self.n_samples):
-            x, y =
+            x, y = self.random_generator.randint(0, x_max - width), self.random_generator.randint(
+                0, y_max - height
+            )
             yield x, y
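With the generator owned by the instance, repeated `sample` calls no longer disturb global random state. A usage sketch; the import path follows the module layout in the file list and the keyword order is an assumption, since neither is shown verbatim in this diff:

from eva.vision.data.wsi.patching import samplers  # import seen in coordinates.py above

sampler = samplers.RandomSampler(n_samples=4, seed=42)  # argument order assumed
for x, y in sampler.sample(width=224, height=224, layer_shape=(10_000, 8_000)):
    print(x, y)  # top-left patch corners, deterministic for a fixed seed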
eva/vision/losses/__init__.py CHANGED

eva/vision/losses/dice.py CHANGED
@@ -1,4 +1,6 @@
-"""Dice loss."""
+"""Dice based loss functions."""
+
+from typing import Sequence, Tuple
 
 import torch
 from monai import losses
@@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss):  # type: ignore
     Extends the implementation from MONAI
     - to support semantic target labels (meaning targets of shape BHW)
     - to support `ignore_index` functionality
+    - accept weight argument in list format
     """
 
-    def __init__(
-
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceLoss.
 
         Args:
             args: Positional arguments from the base class.
             ignore_index: Specifies a target value that is ignored and
                 does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
             kwargs: Key-word arguments from the base class.
         """
-
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
 
         self.ignore_index = ignore_index
 
     @override
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
-
-
-        targets = targets * mask
-        inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
 
         if targets.ndim == 3:
             targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
 
         return super().forward(inputs, targets)
+
+
+class DiceCELoss(losses.dice.DiceCELoss):
+    """Combination of Dice and Cross Entropy Loss.
+
+    Extends the implementation from MONAI
+    - to support semantic target labels (meaning targets of shape BHW)
+    - to support `ignore_index` functionality
+    - accept weight argument in list format
+    """
+
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceCELoss.
+
+        Args:
+            args: Positional arguments from the base class.
+            ignore_index: Specifies a target value that is ignored and
+                does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
+            kwargs: Key-word arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
+
+        self.ignore_index = ignore_index
+
+    @override
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
+
+        return super().forward(inputs, targets)
+
+
+def _apply_ignore_index(
+    inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if ignore_index is not None:
+        mask = targets != ignore_index
+        targets = targets * mask
+        inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+    return inputs, targets
+
+
+def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    if tensor.ndim == 3:
+        return one_hot(tensor[:, None, ...], num_classes=num_classes)
+    return tensor
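Both losses keep MONAI's constructor kwargs and add the two conveniences documented above, so plain `(B, H, W)` integer targets and Python-list class weights work directly. A usage sketch (`softmax=True` comes from the MONAI base class; the re-export via `eva.vision.losses` is presumed from the `__init__.py` change in the file list):

import torch

from eva.vision.losses import DiceLoss  # presumed re-export, see file list above

loss_fn = DiceLoss(
    softmax=True,            # MONAI base-class kwarg
    ignore_index=255,        # pixels labeled 255 contribute no gradient
    weight=[0.2, 1.0, 1.0],  # plain list; converted to a tensor internally
)

logits = torch.randn(2, 3, 64, 64)          # (B, C, H, W) predictions
targets = torch.randint(0, 3, (2, 64, 64))  # (B, H, W) semantic labels
print(loss_fn(logits, targets))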
eva/vision/metrics/__init__.py ADDED

@@ -0,0 +1,11 @@
+"""Default metric collections API."""
+
+from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
+from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+
+__all__ = [
+    "MulticlassSegmentationMetrics",
+    "GeneralizedDiceScore",
+    "MeanIoU",
+]
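For downstream configs this means the segmentation metrics now live under `eva.vision.metrics` rather than `eva.core.metrics` (note the removal of `eva/core/metrics/mean_iou.py` in the file list). Imports grounded in the `__all__` above:

# New import locations as of 0.1.3:
from eva.vision.metrics import GeneralizedDiceScore, MeanIoU, MulticlassSegmentationMetrics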
eva/{core → vision}/metrics/defaults/segmentation/__init__.py RENAMED

@@ -1,5 +1,5 @@
 """Default segmentation metric collections API."""
 
-from eva.
+from eva.vision.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
 
 __all__ = ["MulticlassSegmentationMetrics"]
eva/{core → vision}/metrics/defaults/segmentation/multiclass.py RENAMED

@@ -1,6 +1,7 @@
 """Default metric collection for multiclass semantic segmentation tasks."""
 
-from eva.core.metrics import
+from eva.core.metrics import structs
+from eva.vision.metrics.segmentation import generalized_dice, mean_iou
 
 
 class MulticlassSegmentationMetrics(structs.MetricCollection):
eva/vision/metrics/segmentation/BUILD ADDED

@@ -0,0 +1 @@
+python_sources()