kaiko-eva 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/core/data/dataloaders/dataloader.py +3 -1
- eva/core/interface/interface.py +22 -0
- eva/core/loggers/log/__init__.py +2 -1
- eva/core/loggers/log/table.py +73 -0
- eva/core/trainers/_recorder.py +7 -3
- eva/core/trainers/functional.py +2 -2
- eva/vision/data/datasets/__init__.py +0 -2
- eva/vision/data/datasets/segmentation/__init__.py +0 -2
- eva/vision/models/networks/decoders/segmentation/__init__.py +2 -0
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +5 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +151 -0
- {kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/METADATA +1 -1
- {kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/RECORD +16 -17
- eva/vision/data/datasets/segmentation/metadata/_total_segmentator.py +0 -91
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +0 -414
- {kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/licenses/LICENSE +0 -0

eva/core/data/dataloaders/dataloader.py
CHANGED

@@ -77,7 +77,9 @@ class DataLoader:
             shuffle=self.shuffle,
             sampler=sampler or self.sampler,
             batch_sampler=self.batch_sampler,
-            num_workers=
+            num_workers=(
+                multiprocessing.cpu_count() if self.num_workers is None else self.num_workers
+            ),
             collate_fn=self.collate_fn,
             pin_memory=self.pin_memory,
             drop_last=self.drop_last,
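
With this change, leaving `num_workers` as `None` makes the loader use one worker per available CPU core. A minimal sketch of that fallback (hypothetical helper name, not eva's code):

    import multiprocessing

    def resolve_num_workers(num_workers: int | None) -> int:
        """Use the configured value, or all available CPU cores when it is None."""
        return multiprocessing.cpu_count() if num_workers is None else num_workers

    print(resolve_num_workers(None))  # e.g. 16 on a 16-core machine
    print(resolve_num_workers(4))     # 4
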
eva/core/interface/interface.py
CHANGED

@@ -102,6 +102,9 @@ class Interface:
             model: The model module to use but not modify.
             data: The data module containing validation data.
         """
+        if getattr(data.datasets, "val", None) is None:
+            raise ValueError("The provided data module does not contain a validation dataset.")
+
         eva_trainer.run_evaluation_session(
             base_trainer=trainer,
             base_model=model,

@@ -110,3 +113,22 @@ class Interface:
             n_runs=trainer.n_runs,
             verbose=trainer.n_runs > 1,
         )
+
+    def test(
+        self,
+        trainer: eva_trainer.Trainer,
+        model: modules.ModelModule,
+        data: datamodules.DataModule,
+    ) -> None:
+        """Same as validate, but runs the test stage."""
+        if getattr(data.datasets, "test", None) is None:
+            raise ValueError("The provided data module does not contain a test dataset.")
+
+        eva_trainer.run_evaluation_session(
+            base_trainer=trainer,
+            base_model=model,
+            datamodule=data,
+            stages=["test"],
+            n_runs=trainer.n_runs,
+            verbose=trainer.n_runs > 1,
+        )
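
Both `validate` and the new `test` entry point now fail fast when the requested split is missing from the data module. A self-contained sketch of that guard (the `SimpleNamespace` stand-in is illustrative, not an eva class):

    from types import SimpleNamespace

    def _require_split(datasets, split: str) -> None:
        # Mirrors the guard added to Interface.validate ("val") and Interface.test ("test").
        if getattr(datasets, split, None) is None:
            raise ValueError(f"The provided data module does not contain a {split} dataset.")

    datasets = SimpleNamespace(val=None, test=["dummy_sample"])
    _require_split(datasets, "test")  # passes silently
    try:
        _require_split(datasets, "val")
    except ValueError as error:
        print(error)  # The provided data module does not contain a val dataset.
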
eva/core/loggers/log/__init__.py
CHANGED

eva/core/loggers/log/table.py
ADDED

@@ -0,0 +1,73 @@
+"""Table log functionality."""
+
+import functools
+from typing import List
+
+import pandas as pd
+
+from eva.core.loggers import loggers
+from eva.core.loggers.log import utils
+
+
+@functools.singledispatch
+def log_table(
+    logger,
+    tag: str,
+    columns: List[str] | None = None,
+    data: List[List[str]] | None = None,
+    dataframe: pd.DataFrame | None = None,
+    step: int = 0,
+) -> None:
+    """Adds a a table to the logger.
+
+    The table can be defined either with `columns` and `data` or with `dataframe`.
+
+    Args:
+        logger: The logger to log the table to.
+        tag: The log tag.
+        columns: The column names of the table.
+        data: The data of the table as a list of lists.
+        dataframe: A pandas DataFrame to log.
+        step: The global step of the log.
+    """
+    utils.raise_not_supported(logger, "table")
+
+
+@log_table.register
+def _(
+    loggers: list,
+    tag: str,
+    columns: List[str] | None = None,
+    data: List[List[str]] | None = None,
+    dataframe: pd.DataFrame | None = None,
+    step: int = 0,
+) -> None:
+    """Adds a table to a list of supported loggers."""
+    for logger in loggers:
+        log_table(
+            logger,
+            tag=tag,
+            columns=columns,
+            data=data,
+            dataframe=dataframe,
+            step=step,
+        )
+
+
+@log_table.register
+def _(
+    logger: loggers.WandbLogger,
+    tag: str,
+    columns: List[str] | None = None,
+    data: List[List[str]] | None = None,
+    dataframe: pd.DataFrame | None = None,
+    step: int = 0,
+) -> None:
+    """Adds a table to a Wandb logger."""
+    logger.log_table(
+        key=tag,
+        columns=columns,
+        data=data,
+        dataframe=dataframe,
+        step=step,
+    )
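
`log_table` is a `functools.singledispatch` function: a list of loggers fans out to each entry, a `WandbLogger` forwards to its native `log_table`, and any other logger type goes through `utils.raise_not_supported`. A self-contained sketch of that dispatch pattern (print stand-ins instead of real loggers):

    import functools
    from typing import List

    @functools.singledispatch
    def log_table(logger, tag: str, columns: List[str], data: List[List[str]], step: int = 0) -> None:
        # Fallback for unsupported logger types (eva raises via utils.raise_not_supported).
        print(f"{type(logger).__name__}: tables not supported")

    @log_table.register
    def _(logger: list, tag: str, columns: List[str], data: List[List[str]], step: int = 0) -> None:
        for single_logger in logger:  # a list fans out to every attached logger
            log_table(single_logger, tag, columns, data, step=step)

    log_table(["wandb_stub", "csv_stub"], "val/predictions", ["sample_id", "label"], [["s0001", "kidney"]])
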
eva/core/trainers/_recorder.py
CHANGED

@@ -86,10 +86,13 @@ class SessionRecorder:

     def update(
         self,
-        validation_scores: _EVALUATE_OUTPUT,
+        validation_scores: _EVALUATE_OUTPUT | None = None,
         test_scores: _EVALUATE_OUTPUT | None = None,
     ) -> None:
         """Updates the state of the tracked metrics in-place."""
+        if validation_scores is None and test_scores is None:
+            raise ValueError("At least one of validation_scores or test_scores must be provided.")
+
         self._update_validation_metrics(validation_scores)
         self._update_test_metrics(test_scores)

@@ -117,9 +120,10 @@ class SessionRecorder:
         self._validation_metrics = []
         self._test_metrics = []

-    def _update_validation_metrics(self, metrics: _EVALUATE_OUTPUT) -> None:
+    def _update_validation_metrics(self, metrics: _EVALUATE_OUTPUT | None) -> None:
         """Updates the validation metrics in-place."""
-
+        if metrics:
+            self._validation_metrics = _update_session_metrics(self._validation_metrics, metrics)

     def _update_test_metrics(self, metrics: _EVALUATE_OUTPUT | None) -> None:
         """Updates the test metrics in-place."""
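
The recorder can therefore be fed test scores alone (for sessions without a validation split), while calling it with neither argument is now an explicit error. A minimal sketch of the new contract (standalone function, not the class itself):

    from typing import Any, Dict, List

    def update(validation_scores: List[Dict[str, Any]] | None = None,
               test_scores: List[Dict[str, Any]] | None = None) -> None:
        # Mirrors SessionRecorder.update: at least one score collection is required.
        if validation_scores is None and test_scores is None:
            raise ValueError("At least one of validation_scores or test_scores must be provided.")

    update(test_scores=[{"test/accuracy": 0.91}])  # allowed in 0.3.3 (test-only sessions)
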
eva/core/trainers/functional.py
CHANGED

@@ -47,7 +47,7 @@ def run_evaluation_session(
         stages=stages,
         verbose=not verbose,
     )
-    if validation_scores:
+    if validation_scores or test_scores:
         recorder.update(validation_scores, test_scores)
     recorder.save()

@@ -89,7 +89,7 @@ def run_evaluation(

     if "fit" in stages:
         trainer.fit(model, datamodule=datamodule)
-    if "validate" in stages:
+    if "validate" in stages and getattr(datamodule.datasets, "val", None) is not None:
         validation_scores = trainer.validate(
             model=model,
             datamodule=datamodule,
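
Together with the interface guard above, the behaviour is two-tiered: the high-level `Interface.validate`/`Interface.test` calls raise on a missing split, whereas `run_evaluation` simply skips the validate stage when the data module has no `val` dataset. A small sketch of the skip condition (illustrative stand-ins, not eva code):

    from types import SimpleNamespace
    from typing import List

    def should_validate(stages: List[str], datasets) -> bool:
        # Mirrors the 0.3.3 condition: run validation only if requested and a val split exists.
        return "validate" in stages and getattr(datasets, "val", None) is not None

    print(should_validate(["fit", "validate"], SimpleNamespace(val=["slice"])))  # True
    print(should_validate(["fit", "validate"], SimpleNamespace(val=None)))       # False -> stage skipped
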

eva/vision/data/datasets/__init__.py
CHANGED

@@ -22,7 +22,6 @@ from eva.vision.data.datasets.segmentation import (
     LiTS17,
     MoNuSAC,
     MSDTask7Pancreas,
-    TotalSegmentator2D,
 )
 from eva.vision.data.datasets.vision import VisionDataset
 from eva.vision.data.datasets.wsi import MultiWsiDataset, WsiDataset

@@ -40,7 +39,6 @@ __all__ = [
     "PANDASmall",
     "Camelyon16",
     "PatchCamelyon",
-    "TotalSegmentator2D",
     "UniToPatho",
     "WsiClassificationDataset",
     "CoNSeP",

eva/vision/data/datasets/segmentation/__init__.py
CHANGED

@@ -7,7 +7,6 @@ from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentat
 from eva.vision.data.datasets.segmentation.lits17 import LiTS17
 from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
 from eva.vision.data.datasets.segmentation.msd_task7_pancreas import MSDTask7Pancreas
-from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D

 __all__ = [
     "BCSS",

@@ -17,5 +16,4 @@ __all__ = [
     "LiTS17",
     "MSDTask7Pancreas",
     "MoNuSAC",
-    "TotalSegmentator2D",
 ]

eva/vision/models/networks/decoders/segmentation/__init__.py
CHANGED

@@ -9,6 +9,7 @@ from eva.vision.models.networks.decoders.segmentation.semantic import (
     ConvDecoderWithImage,
     SingleLinearDecoder,
     SwinUNETRDecoder,
+    SwinUNETRDecoderWithProjection,
 )

 __all__ = [

@@ -20,4 +21,5 @@ __all__ = [
     "LinearDecoder",
     "SingleLinearDecoder",
     "SwinUNETRDecoder",
+    "SwinUNETRDecoderWithProjection",
 ]

eva/vision/models/networks/decoders/segmentation/semantic/__init__.py
CHANGED

@@ -5,7 +5,10 @@ from eva.vision.models.networks.decoders.segmentation.semantic.common import (
     ConvDecoderMS,
     SingleLinearDecoder,
 )
-from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import
+from eva.vision.models.networks.decoders.segmentation.semantic.swin_unetr import (
+    SwinUNETRDecoder,
+    SwinUNETRDecoderWithProjection,
+)
 from eva.vision.models.networks.decoders.segmentation.semantic.with_image import (
     ConvDecoderWithImage,
 )

@@ -16,4 +19,5 @@ __all__ = [
     "ConvDecoderWithImage",
     "SingleLinearDecoder",
     "SwinUNETRDecoder",
+    "SwinUNETRDecoderWithProjection",
 ]

eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py
CHANGED

@@ -102,3 +102,154 @@ class SwinUNETRDecoder(nn.Module):
             (batch_size, n_classes, image_height, image_width).
         """
         return self._forward_features(features)
+
+
+class SwinUNETRDecoderWithProjection(nn.Module):
+    """Swin transformer decoder based on UNETR [0].
+
+    This implementation adds additional projection layers to reduce
+    the number of channels in the feature maps before applying the upscaling
+    convolutional blocks. This reduces the number of trainable parameters
+    significantly and is useful when scaling up the encoder architecture.
+
+    - [0] UNETR: Transformers for 3D Medical Image Segmentation
+        https://arxiv.org/pdf/2103.10504
+    """
+
+    def __init__(
+        self,
+        out_channels: int,
+        feature_size: int = 48,
+        spatial_dims: int = 3,
+        project_dims: list[int] | None = None,
+        checkpoint_path: str | None = None,
+    ) -> None:
+        """Builds the decoder.
+
+        Args:
+            out_channels: Number of output channels.
+            feature_size: Dimension of network feature size.
+            spatial_dims: Number of spatial dimensions.
+            project_dims: List of 6 dimensions to project encoder features to.
+                If None, uses default channel progression based on feature_size.
+                This is not part of the original implementation, but helps
+                to reduce the number of decoder parameters when scaling up
+                the encoder architecture (feature_size).
+            checkpoint_path: Path to the checkpoint file.
+        """
+        super().__init__()
+
+        self._checkpoint_path = checkpoint_path
+        self._project_dims = project_dims
+
+        if project_dims is not None and len(project_dims) != 6:
+            raise ValueError(
+                f"project_dims must have exactly 6 dimensions, got {len(project_dims)}"
+            )
+
+        channel_dims = project_dims or [
+            feature_size,
+            feature_size,
+            feature_size * 2,
+            feature_size * 4,
+            feature_size * 8,
+            feature_size * 16,
+        ]
+
+        self.decoder5 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[5],
+            out_channels=channel_dims[4],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder4 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[4],
+            out_channels=channel_dims[3],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder3 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[3],
+            out_channels=channel_dims[2],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder2 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[2],
+            out_channels=channel_dims[1],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.decoder1 = unetr_block.UnetrUpBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[1],
+            out_channels=channel_dims[0],
+            kernel_size=3,
+            upsample_kernel_size=2,
+            norm_name="instance",
+            res_block=True,
+        )
+        self.out = dynunet_block.UnetOutBlock(
+            spatial_dims=spatial_dims,
+            in_channels=channel_dims[0],
+            out_channels=out_channels,
+        )
+
+        if self._project_dims:
+            conv_layer = nn.Conv2d if spatial_dims == 2 else nn.Conv3d
+            self.proj_enc0 = conv_layer(feature_size, self._project_dims[0], kernel_size=1)
+            self.proj_enc1 = conv_layer(feature_size, self._project_dims[1], kernel_size=1)
+            self.proj_enc2 = conv_layer(feature_size * 2, self._project_dims[2], kernel_size=1)
+            self.proj_enc3 = conv_layer(feature_size * 4, self._project_dims[3], kernel_size=1)
+            self.proj_hid3 = conv_layer(feature_size * 8, self._project_dims[4], kernel_size=1)
+            self.proj_dec4 = conv_layer(feature_size * 16, self._project_dims[5], kernel_size=1)
+
+    def _forward_features(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function for multi-level feature maps to a single one."""
+        enc0, enc1, enc2, enc3, hid3, dec4 = self._project_features(features)
+        dec3 = self.decoder5(dec4, hid3)
+        dec2 = self.decoder4(dec3, enc3)
+        dec1 = self.decoder3(dec2, enc2)
+        dec0 = self.decoder2(dec1, enc1)
+        out = self.decoder1(dec0, enc0)
+        return self.out(out)
+
+    def _project_features(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
+        """Projects features using 1x1 to reduce number of channels."""
+        return (
+            [
+                self.proj_enc0(features[0]),
+                self.proj_enc1(features[1]),
+                self.proj_enc2(features[2]),
+                self.proj_enc3(features[3]),
+                self.proj_hid3(features[4]),
+                self.proj_dec4(features[5]),
+            ]
+            if self._project_dims
+            else features
+        )
+
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+        """Maps the patch embeddings to a segmentation mask.
+
+        Args:
+            features: List of multi-level intermediate features from
+                :class:`SwinUNETREncoder`.
+
+        Returns:
+            Tensor containing scores for all of the classes with shape
+            (batch_size, n_classes, image_height, image_width).
+        """
+        return self._forward_features(features)
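
Construction mirrors `SwinUNETRDecoder`, with the optional `project_dims` (exactly six entries) inserting 1x1 convolutions that shrink each encoder feature map before the UNETR up-blocks, which is what keeps the decoder small for wide encoders. A hedged construction sketch (argument values are illustrative only):

    from eva.vision.models.networks.decoders.segmentation import SwinUNETRDecoderWithProjection

    decoder = SwinUNETRDecoderWithProjection(
        out_channels=5,                            # number of segmentation classes
        feature_size=96,                           # wider encoder than the default 48
        spatial_dims=2,
        project_dims=[48, 48, 96, 192, 384, 768],  # channels after the 1x1 projections
    )
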

{kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: kaiko-eva
-Version: 0.3.
+Version: 0.3.3
 Summary: Evaluation Framework for oncology foundation models.
 Keywords: machine-learning,evaluation-framework,oncology,foundation-models
 Author-Email: Ioannis Gatopoulos <ioannis@kaiko.ai>, =?utf-8?q?Nicolas_K=C3=A4nzig?= <nicolas@kaiko.ai>, Roman Moser <roman@kaiko.ai>

{kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/RECORD
CHANGED

@@ -19,7 +19,7 @@ eva/core/data/__init__.py,sha256=yG3BeOWhp1EjVYMFqx8M_TBWFDyfIwwksQGQmMdSPaI,340
 eva/core/data/dataloaders/__init__.py,sha256=0AvpsPOdh4NX5rwkX9Th1M_rzxZVVzTPTdC5oTGFd5w,194
 eva/core/data/dataloaders/collate_fn/__init__.py,sha256=CfSAVrPD36shpyYAkob2ny05VWymb95MutawQcZkbuo,134
 eva/core/data/dataloaders/collate_fn/collate.py,sha256=oBdxaDCIaXBB6H8LB90Qsi2Inw1tyHGF4kAFBINPOeQ,689
-eva/core/data/dataloaders/dataloader.py,sha256=
+eva/core/data/dataloaders/dataloader.py,sha256=v_UL5p78EiyLqXUhWsS0EbfW0iO-DAWOMv7-WTclOvA,2750
 eva/core/data/datamodules/__init__.py,sha256=qZchYbgxo9lxYnGoqdk0C6MfS2IbF0WItO0kCdP9Mqc,229
 eva/core/data/datamodules/call.py,sha256=jjj9w3UXYuQB-qyCcw1EZpRJW10OC1I3dvgvsuQWLck,940
 eva/core/data/datamodules/datamodule.py,sha256=_pK59oXDe53oDkmv6eoJUvfl44WlFkrbC8KXSRMs_20,5533

@@ -48,13 +48,14 @@ eva/core/data/transforms/padding/pad_2d_tensor.py,sha256=J4maGFmeQf9IHRxt5kU-6eI
 eva/core/data/transforms/sampling/__init__.py,sha256=BFKbvRjlZrwS0GcNrM54ZSWt6PrQARfFlXM1jJ-wpvo,149
 eva/core/data/transforms/sampling/sample_from_axis.py,sha256=Zbhp94lVa70WQKmSOKMTsOMe2c7wLqNZto7JqWhSdtI,1229
 eva/core/interface/__init__.py,sha256=chdpKXipxe1NP-Fgr_d9r6X1gMna0XiEa38waJ6FzTM,98
-eva/core/interface/interface.py,sha256=
+eva/core/interface/interface.py,sha256=EeBrIqUlO497BOZGZrtO-fZnc_BhIJrrqyQmRfqWXcw,4591
 eva/core/loggers/__init__.py,sha256=4YMLNlN9LnuKqhBI1R1keh69dmMD-2lcH3HKwwyn380,266
 eva/core/loggers/dummy.py,sha256=Y7ypH0ecSAIkkZ5LzTmNNEzlKkqeaHfUNMCDKVOg6D4,1204
 eva/core/loggers/experimental_loggers.py,sha256=p5uCK_9QCYufRhE-LZQUJWbhGElyobX_zRM78yX4p2o,230
-eva/core/loggers/log/__init__.py,sha256=
+eva/core/loggers/log/__init__.py,sha256=TdsUwcaB2jW0--HVGIr7_polTfmr7iOmXtSmr1wK9_c,251
 eva/core/loggers/log/image.py,sha256=iUwntQCdRNLtkSdqu8CvV34l06zPYVo4NAW2gUeiJIM,1490
 eva/core/loggers/log/parameters.py,sha256=7Xi-I5gQvEVv71d58bwdZ-Hb4287NXxaUyMfriq_KDU,1634
+eva/core/loggers/log/table.py,sha256=HULCo5icDB6UOIXXMci0eo32Pl1-YRWzaOAB-ZBVUak,1726
 eva/core/loggers/log/utils.py,sha256=k4Q7uKpAQctfDv0EEYPnPv6wt9LnckEeqGvbYSLfKO0,415
 eva/core/loggers/loggers.py,sha256=igHxdxJSotWSg6nEOKnfFuBszzblHgi8T7sBrE00FEs,166
 eva/core/loggers/utils/wandb.py,sha256=GdwzEeFTAng5kl_kIVRxKL7rvwqyicQHSaZS8VSMXvU,747

@@ -99,9 +100,9 @@ eva/core/models/wrappers/huggingface.py,sha256=-_fA81YRnoMc7O7SbrnCEj1dM_xArpQ8W
 eva/core/models/wrappers/onnx.py,sha256=34li_xSwPryN8nJDrFyif_Hve1AEH7Ry9E_lZmf7JJM,1834
 eva/core/trainers/__init__.py,sha256=jhsKJF7HAae7EOiG3gKIAHH_h3dZlTE2JRcCHJmOzJc,208
 eva/core/trainers/_logging.py,sha256=gi4FqPy2GuVmh0WZY6mYwF7zMPvnoFA050B0XdCP6PU,2571
-eva/core/trainers/_recorder.py,sha256=
+eva/core/trainers/_recorder.py,sha256=M-BJHLgqGxR_MSV6f_WC7GN2JHYEEinV1-hNLpH667A,8062
 eva/core/trainers/_utils.py,sha256=M3h8lVhUmkeSiEXpX9hRdMvThGFCnTP15gv-hd1CZkc,321
-eva/core/trainers/functional.py,sha256=
+eva/core/trainers/functional.py,sha256=_Mw-NIPU2tPffxpK5t3sHBmVI6u163phCpoJFiauH7E,4583
 eva/core/trainers/trainer.py,sha256=a3OwLWOZKDqxayrd0ugUmxJKyQx6XDb4GHtdL8-AEV0,4826
 eva/core/utils/__init__.py,sha256=cndVBvtYxEW7hykH39GCNVI86zkXNn8Lw2A0sUJHS04,237
 eva/core/utils/clone.py,sha256=qcThZOuAs1cs0uV3BL5eKeM2VIBjuRPBe1t-NiUFM5Y,569

@@ -146,7 +147,7 @@ eva/vision/data/dataloaders/__init__.py,sha256=9ykBD4vyZ-Yv3IEnqvVcSMURS-gXWjOun
 eva/vision/data/dataloaders/collate_fn/__init__.py,sha256=GCvJaeILmAc_-lhGw8yzj2cC2KG4i1PvSWAyVzPKvVo,146
 eva/vision/data/dataloaders/collate_fn/collection.py,sha256=45s9fKjVBnqfnuGWmJZMtt_DDGnfuf7qkWe0QmxXMKo,611
 eva/vision/data/dataloaders/worker_init.py,sha256=lFWywHGCC4QxHeDXrneF8DQ45XG3WmVltEELJrPyLz0,1182
-eva/vision/data/datasets/__init__.py,sha256=
+eva/vision/data/datasets/__init__.py,sha256=_04LqKv46oUXdmQAlSmWkgZYueHFwM-0iiOSMuFnFDw,976
 eva/vision/data/datasets/_utils.py,sha256=epPcaYE4w2_LtUKLLQJh6qQxUNVBe22JA06k4WUerYQ,1430
 eva/vision/data/datasets/_validators.py,sha256=77WZj8ewsuxUjW5WegJ-7zDuR6WdF5JbaOYdywhKIK4,2594
 eva/vision/data/datasets/classification/__init__.py,sha256=5fOGZxKGPeMCf3Jd9qAOYADPrkZnYg97_QE4DC79AMI,1074

@@ -161,7 +162,7 @@ eva/vision/data/datasets/classification/panda.py,sha256=HVfCvByyajdo5o_waqTpzZWC
 eva/vision/data/datasets/classification/patch_camelyon.py,sha256=1yXkfP680qxkQUFAPKRFbZv0cHAFx23s2vvT9th2nKM,7149
 eva/vision/data/datasets/classification/unitopatho.py,sha256=IO3msEsuOnmdcYZxF-eBpo0K97y54rWFmCb_KxuF4bk,5129
 eva/vision/data/datasets/classification/wsi.py,sha256=YMGxU8ECjudizt_uXUevuPS8k66HxtEQ7M2IZJmL6kE,4079
-eva/vision/data/datasets/segmentation/__init__.py,sha256=
+eva/vision/data/datasets/segmentation/__init__.py,sha256=f0q9tzk4ahaZfrw_SgIE_puk_D7qmkSCKX1FP9aJITU,668
 eva/vision/data/datasets/segmentation/_utils.py,sha256=aXUHrnbefP6-OgSvDQHqssFKhUwETul_8aosqYiOfm8,3065
 eva/vision/data/datasets/segmentation/bcss.py,sha256=rqk6VqK0QCHLFnMnDuHd1JPJVK5_C6WnsmnNSKBw6Uo,8230
 eva/vision/data/datasets/segmentation/btcv.py,sha256=9rlEqGyb2SGJBY6Oj42FlHajQF8csf1Jq6jeuPSsfXI,8396

@@ -170,10 +171,8 @@ eva/vision/data/datasets/segmentation/embeddings.py,sha256=RsTuAwGEJPnWPY7q3pwcj
 eva/vision/data/datasets/segmentation/lits17.py,sha256=kcSCKxsgtUuCD1YEYvrb_L_BgOtZC8xDq1lX8ldSZc4,7635
 eva/vision/data/datasets/segmentation/metadata/__init__.py,sha256=o9Od0v6N9dNdf8hfefn2QaNNCD2sZMvc2K58zHA_Nrg,24
 eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py,sha256=O2-ye0A7wIjcI_D857uvpYw-jckTqfhBUrhinqSNWq0,2553
-eva/vision/data/datasets/segmentation/metadata/_total_segmentator.py,sha256=DTaQaAisY7j1h0-zYk1_81Sr4b3D9PTMieYX0PMPtIc,3127
 eva/vision/data/datasets/segmentation/monusac.py,sha256=iv9-MFaTsGfGV1u6_lQNcSEeSpmVBDQC1Oa123iEtu0,8410
 eva/vision/data/datasets/segmentation/msd_task7_pancreas.py,sha256=dTsPD73PAP15VOXdHnX4eQqbpz2jGpCB31YISzinUd4,8964
-eva/vision/data/datasets/segmentation/total_segmentator_2d.py,sha256=TGz67AGuv8_Bm5DM5TyCtzRTuGXOuctZZNxdQtBxF1g,16987
 eva/vision/data/datasets/structs.py,sha256=RaTDW-B36PumcR5gymhCiX-r8GiKqIFcjqoEEjjFyUE,389
 eva/vision/data/datasets/vision.py,sha256=-_WRiyICMgqABR6Ay_RKBMfsPGwgx9MQfCA7WChHo24,3219
 eva/vision/data/datasets/wsi.py,sha256=dEAT_Si_Qb3qdSovUPeoiWeoPb7m-NGYqq44e3UXHk8,8384

@@ -256,13 +255,13 @@ eva/vision/models/networks/backbones/timm/backbones.py,sha256=ZbF9MMiL4Ylyy79XLe
 eva/vision/models/networks/backbones/universal/__init__.py,sha256=xgn3crSqlmUPYz-t2CR1zDKxhlyAEeApA-a6Y_eWQvc,417
 eva/vision/models/networks/backbones/universal/vit.py,sha256=To0OzwpuX5Y5PwjGidwV0Ssq3xa81dve081buwG_Ofg,3658
 eva/vision/models/networks/decoders/__init__.py,sha256=RXFWmoYw2i6E9VOUCJmU8c72icHannVuo-cUKy6fnLM,200
-eva/vision/models/networks/decoders/segmentation/__init__.py,sha256=
+eva/vision/models/networks/decoders/segmentation/__init__.py,sha256=yVrRo2OisNRAlxDjWJGwipKA9HGeqRXd1ZL88eltoy4,726
 eva/vision/models/networks/decoders/segmentation/base.py,sha256=b2TIJKiJR9vejVRpNyedMJLPTrpHhAEXvco8atb9TPU,411
 eva/vision/models/networks/decoders/segmentation/decoder2d.py,sha256=HRonYTSriiq13aZCSNiYUc484qfOhkVT0yFiMW06CDc,4472
 eva/vision/models/networks/decoders/segmentation/linear.py,sha256=ui3-Y0rl4VEF75-sUghaF29P9wpxCVlp5iR_Ym-utUE,4666
-eva/vision/models/networks/decoders/segmentation/semantic/__init__.py,sha256=
+eva/vision/models/networks/decoders/segmentation/semantic/__init__.py,sha256=9QnepLMzQVE-wAZJXx0napVutg1HtkbDERcPsoevWGg,622
 eva/vision/models/networks/decoders/segmentation/semantic/common.py,sha256=FSf_eI-FaBroxPRJd4TiV97RCreauJh1IznIVzBT2eg,2528
-eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py,sha256=
+eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py,sha256=eSFvHng2lrc-Wd4g9CW4z8-yfKndbl0c7-sKhOautBU,9170
 eva/vision/models/networks/decoders/segmentation/semantic/with_image.py,sha256=I5PyGKKo8DcXYcw4xlCFzuavRJNRrzGT-szpDidMPXI,3516
 eva/vision/models/networks/decoders/segmentation/typings.py,sha256=rY4CXp0MNF16SHnx9TgGjXI_r8bVGSqAWdR835hXndg,537
 eva/vision/models/wrappers/__init__.py,sha256=ogmr-eeVuGaOCcsuxSp6PGyauP2QqWTb8dGTtbC7lRU,210

@@ -277,8 +276,8 @@ eva/vision/utils/io/image.py,sha256=IdOkr5MYqhYHz8U9drZ7wULTM3YHwCWSjZlu_Qdl4GQ,
 eva/vision/utils/io/mat.py,sha256=qpGifyjmpE0Xhv567Si7-zxKrgkgE0sywP70cHiLFGU,808
 eva/vision/utils/io/nifti.py,sha256=TFMgNhLqIK3sl3RjIRXEABM7FmSQjqVOwk1vXkuvX2w,4983
 eva/vision/utils/io/text.py,sha256=qYgfo_ZaDZWfG02NkVVYzo5QFySqdCCz5uLA9d-zXtI,701
-kaiko_eva-0.3.
-kaiko_eva-0.3.
-kaiko_eva-0.3.
-kaiko_eva-0.3.
-kaiko_eva-0.3.
+kaiko_eva-0.3.3.dist-info/METADATA,sha256=3-qTJLt0hRJswZbPVixj9e9Bt0gFXDjOp6YbNf6Ohd4,25704
+kaiko_eva-0.3.3.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
+kaiko_eva-0.3.3.dist-info/entry_points.txt,sha256=6CSLu9bmQYJSXEg8gbOzRhxH0AGs75BB-vPm3VvfcNE,88
+kaiko_eva-0.3.3.dist-info/licenses/LICENSE,sha256=e6AEzr7j_R-PYr2qLO-JwLn8y70jbVD3U2mxbRmwcI4,11338
+kaiko_eva-0.3.3.dist-info/RECORD,,

eva/vision/data/datasets/segmentation/metadata/_total_segmentator.py
DELETED

@@ -1,91 +0,0 @@
-"""Utils for TotalSegmentator dataset classes."""
-
-from typing import Dict
-
-reduced_class_mappings: Dict[str, str] = {
-    # Abdominal Organs
-    "spleen": "spleen",
-    "kidney_right": "kidney",
-    "kidney_left": "kidney",
-    "gallbladder": "gallbladder",
-    "liver": "liver",
-    "stomach": "stomach",
-    "pancreas": "pancreas",
-    "small_bowel": "small_bowel",
-    "duodenum": "duodenum",
-    "colon": "colon",
-    # Endocrine System
-    "adrenal_gland_right": "adrenal_gland",
-    "adrenal_gland_left": "adrenal_gland",
-    "thyroid_gland": "thyroid_gland",
-    # Respiratory System
-    "lung_upper_lobe_left": "lungs",
-    "lung_lower_lobe_left": "lungs",
-    "lung_upper_lobe_right": "lungs",
-    "lung_middle_lobe_right": "lungs",
-    "lung_lower_lobe_right": "lungs",
-    "trachea": "trachea",
-    "esophagus": "esophagus",
-    # Urogenital System
-    "urinary_bladder": "urogenital_system",
-    "prostate": "urogenital_system",
-    "kidney_cyst_left": "kidney_cyst",
-    "kidney_cyst_right": "kidney_cyst",
-    # Vertebral Column
-    **{f"vertebrae_{v}": "vertebrae" for v in ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]},
-    **{f"vertebrae_{v}": "vertebrae" for v in [f"T{i}" for i in range(1, 13)]},
-    **{f"vertebrae_{v}": "vertebrae" for v in [f"L{i}" for i in range(1, 6)]},
-    "vertebrae_S1": "vertebrae",
-    "sacrum": "sacral_spine",
-    # Cardiovascular System
-    "heart": "heart",
-    "aorta": "aorta",
-    "pulmonary_vein": "veins",
-    "brachiocephalic_trunk": "arteries",
-    "subclavian_artery_right": "arteries",
-    "subclavian_artery_left": "arteries",
-    "common_carotid_artery_right": "arteries",
-    "common_carotid_artery_left": "arteries",
-    "brachiocephalic_vein_left": "veins",
-    "brachiocephalic_vein_right": "veins",
-    "atrial_appendage_left": "atrial_appendage",
-    "superior_vena_cava": "veins",
-    "inferior_vena_cava": "veins",
-    "portal_vein_and_splenic_vein": "veins",
-    "iliac_artery_left": "arteries",
-    "iliac_artery_right": "arteries",
-    "iliac_vena_left": "veins",
-    "iliac_vena_right": "veins",
-    # Upper Extremity Bones
-    "humerus_left": "humerus",
-    "humerus_right": "humerus",
-    "scapula_left": "scapula",
-    "scapula_right": "scapula",
-    "clavicula_left": "clavicula",
-    "clavicula_right": "clavicula",
-    # Lower Extremity Bones
-    "femur_left": "femur",
-    "femur_right": "femur",
-    "hip_left": "hip",
-    "hip_right": "hip",
-    # Muscles
-    "gluteus_maximus_left": "gluteus",
-    "gluteus_maximus_right": "gluteus",
-    "gluteus_medius_left": "gluteus",
-    "gluteus_medius_right": "gluteus",
-    "gluteus_minimus_left": "gluteus",
-    "gluteus_minimus_right": "gluteus",
-    "autochthon_left": "autochthon",
-    "autochthon_right": "autochthon",
-    "iliopsoas_left": "iliopsoas",
-    "iliopsoas_right": "iliopsoas",
-    # Central Nervous System
-    "brain": "brain",
-    "spinal_cord": "spinal_cord",
-    # Skull and Thoracic Cage
-    "skull": "skull",
-    **{f"rib_left_{i}": "ribs" for i in range(1, 13)},
-    **{f"rib_right_{i}": "ribs" for i in range(1, 13)},
-    "costal_cartilages": "ribs",
-    "sternum": "sternum",
-}
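
For reference, the removed mapping grouped the fine-grained TotalSegmentator labels (left/right organs, individual vertebrae and ribs) into coarser classes; the dataset then derived its label set from the mapping's values. An illustrative fragment of that behaviour (a few entries only, not the full removed table):

    reduced_class_mappings = {
        "kidney_right": "kidney",
        "kidney_left": "kidney",
        "lung_upper_lobe_left": "lungs",
        "lung_lower_lobe_right": "lungs",
    }
    # The dataset built its label set from the mapped values plus a background class.
    grouped = ["background"] + sorted(set(reduced_class_mappings.values()))
    print(grouped)  # ['background', 'kidney', 'lungs']
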

eva/vision/data/datasets/segmentation/total_segmentator_2d.py
DELETED

@@ -1,414 +0,0 @@
-"""TotalSegmentator 2D segmentation dataset class."""
-
-import functools
-import hashlib
-import os
-import re
-from glob import glob
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Tuple
-
-import numpy as np
-import numpy.typing as npt
-import torch
-from torchvision import tv_tensors
-from torchvision.datasets import utils
-from typing_extensions import override
-
-from eva.core.utils import io as core_io
-from eva.core.utils import multiprocessing
-from eva.vision.data.datasets import _validators, structs, vision
-from eva.vision.data.datasets.segmentation.metadata import _total_segmentator
-from eva.vision.utils import io
-
-
-class TotalSegmentator2D(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
-    """TotalSegmentator 2D segmentation dataset."""
-
-    _expected_dataset_lengths: Dict[str, int] = {
-        "train_small": 35089,
-        "val_small": 1283,
-        "train_full": 278190,
-        "val_full": 14095,
-        "test_full": 25578,
-    }
-    """Dataset version and split to the expected size."""
-
-    _sample_every_n_slices: int | None = None
-    """The amount of slices to sub-sample per 3D CT scan image."""
-
-    _resources_full: List[structs.DownloadResource] = [
-        structs.DownloadResource(
-            filename="Totalsegmentator_dataset_v201.zip",
-            url="https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_v201.zip",
-            md5="fe250e5718e0a3b5df4c4ea9d58a62fe",
-        ),
-    ]
-    """Resources for the full dataset version."""
-
-    _resources_small: List[structs.DownloadResource] = [
-        structs.DownloadResource(
-            filename="Totalsegmentator_dataset_small_v201.zip",
-            url="https://zenodo.org/records/10047263/files/Totalsegmentator_dataset_small_v201.zip",
-            md5="6b5524af4b15e6ba06ef2d700c0c73e0",
-        ),
-    ]
-    """Resources for the small dataset version."""
-
-    _license: str = (
-        "Creative Commons Attribution 4.0 International "
-        "(https://creativecommons.org/licenses/by/4.0/deed.en)"
-    )
-    """Dataset license."""
-
-    def __init__(
-        self,
-        root: str,
-        split: Literal["train", "val", "test"] | None,
-        version: Literal["small", "full"] | None = "full",
-        download: bool = False,
-        classes: List[str] | None = None,
-        class_mappings: Dict[str, str] | None = _total_segmentator.reduced_class_mappings,
-        optimize_mask_loading: bool = True,
-        decompress: bool = True,
-        num_workers: int = 10,
-        transforms: Callable | None = None,
-    ) -> None:
-        """Initialize dataset.
-
-        Args:
-            root: Path to the root directory of the dataset. The dataset will
-                be downloaded and extracted here, if it does not already exist.
-            split: Dataset split to use. If `None`, the entire dataset is used.
-            version: The version of the dataset to initialize. If `None`, it will
-                use the files located at root as is and wont perform any checks.
-            download: Whether to download the data for the specified split.
-                Note that the download will be executed only by additionally
-                calling the :meth:`prepare_data` method and if the data does not
-                exist yet on disk.
-            classes: Whether to configure the dataset with a subset of classes.
-                If `None`, it will use all of them.
-            class_mappings: A dictionary that maps the original class names to a
-                reduced set of classes. If `None`, it will use the original classes.
-            optimize_mask_loading: Whether to pre-process the segmentation masks
-                in order to optimize the loading time. In the `setup` method, it
-                will reformat the binary one-hot masks to a semantic mask and store
-                it on disk.
-            decompress: Whether to decompress the ct.nii.gz files when preparing the data.
-                The label masks won't be decompressed, but when enabling optimize_mask_loading
-                it will export the semantic label masks to a single file in uncompressed .nii
-                format.
-            num_workers: The number of workers to use for optimizing the masks &
-                decompressing the .gz files.
-            transforms: A function/transforms that takes in an image and a target
-                mask and returns the transformed versions of both.
-
-        """
-        super().__init__(transforms=transforms)
-
-        self._root = root
-        self._split = split
-        self._version = version
-        self._download = download
-        self._classes = classes
-        self._optimize_mask_loading = optimize_mask_loading
-        self._decompress = decompress
-        self._num_workers = num_workers
-        self._class_mappings = class_mappings
-
-        if self._classes and self._class_mappings:
-            raise ValueError("Both 'classes' and 'class_mappings' cannot be set at the same time.")
-
-        self._samples_dirs: List[str] = []
-        self._indices: List[Tuple[int, int]] = []
-
-    @functools.cached_property
-    @override
-    def classes(self) -> List[str]:
-        def get_filename(path: str) -> str:
-            """Returns the filename from the full path."""
-            return os.path.basename(path).split(".")[0]
-
-        first_sample_labels = os.path.join(self._root, "s0011", "segmentations", "*.nii.gz")
-        all_classes = sorted(map(get_filename, glob(first_sample_labels)))
-        if self._classes:
-            is_subset = all(name in all_classes for name in self._classes)
-            if not is_subset:
-                raise ValueError("Provided class names are not subset of the original ones.")
-            classes = sorted(self._classes)
-        elif self._class_mappings:
-            is_subset = all(name in all_classes for name in self._class_mappings.keys())
-            if not is_subset:
-                raise ValueError("Provided class names are not subset of the original ones.")
-            classes = sorted(set(self._class_mappings.values()))
-        else:
-            classes = all_classes
-        return ["background"] + classes
-
-    @property
-    @override
-    def class_to_idx(self) -> Dict[str, int]:
-        return {label: index for index, label in enumerate(self.classes)}
-
-    @property
-    def _file_suffix(self) -> str:
-        return "nii" if self._decompress else "nii.gz"
-
-    @functools.cached_property
-    def _classes_hash(self) -> str:
-        return hashlib.md5(str(self.classes).encode(), usedforsecurity=False).hexdigest()
-
-    @override
-    def filename(self, index: int) -> str:
-        sample_idx, _ = self._indices[index]
-        sample_dir = self._samples_dirs[sample_idx]
-        return os.path.join(sample_dir, f"ct.{self._file_suffix}")
-
-    @override
-    def prepare_data(self) -> None:
-        if self._download:
-            self._download_dataset()
-        if self._decompress:
-            self._decompress_files()
-        self._samples_dirs = self._fetch_samples_dirs()
-        if self._optimize_mask_loading:
-            self._export_semantic_label_masks()
-
-    @override
-    def configure(self) -> None:
-        self._indices = self._create_indices()
-
-    @override
-    def validate(self) -> None:
-        if self._version is None or self._sample_every_n_slices is not None:
-            return
-
-        if self._classes:
-            last_label = self._classes[-1]
-            n_classes = len(self._classes)
-        elif self._class_mappings:
-            classes = sorted(set(self._class_mappings.values()))
-            last_label = classes[-1]
-            n_classes = len(classes)
-        else:
-            last_label = "vertebrae_T9"
-            n_classes = 117
-
-        _validators.check_dataset_integrity(
-            self,
-            length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
-            n_classes=n_classes + 1,
-            first_and_last_labels=("background", last_label),
-        )
-
-    @override
-    def __len__(self) -> int:
-        return len(self._indices)
-
-    @override
-    def load_data(self, index: int) -> tv_tensors.Image:
-        sample_index, slice_index = self._indices[index]
-        image_path = self._get_image_path(sample_index)
-        image_nii = io.read_nifti(image_path, slice_index)
-        image_array = io.nifti_to_array(image_nii)
-        image_array = self._fix_orientation(image_array)
-        return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))
-
-    @override
-    def load_target(self, index: int) -> tv_tensors.Mask:
-        if self._optimize_mask_loading:
-            mask = self._load_semantic_label_mask(index)
-        else:
-            mask = self._load_target(index)
-        mask = self._fix_orientation(mask)
-        return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64)  # type: ignore
-
-    @override
-    def load_metadata(self, index: int) -> Dict[str, Any]:
-        _, slice_index = self._indices[index]
-        return {"slice_index": slice_index}
-
-    def _load_target(self, index: int) -> npt.NDArray[Any]:
-        sample_index, slice_index = self._indices[index]
-        return self._load_masks_as_semantic_label(sample_index, slice_index)
-
-    def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
-        """Loads the segmentation mask from a semantic label NifTi file."""
-        sample_index, slice_index = self._indices[index]
-        nii = io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
-        return io.nifti_to_array(nii)
-
-    def _load_masks_as_semantic_label(
-        self, sample_index: int, slice_index: int | None = None
-    ) -> npt.NDArray[Any]:
-        """Loads binary masks as a semantic label mask.
-
-        Args:
-            sample_index: The data sample index.
-            slice_index: Whether to return only a specific slice.
-        """
-        masks_dir = self._get_masks_dir(sample_index)
-        classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
-        mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
-        binary_masks = [io.nifti_to_array(io.read_nifti(path, slice_index)) for path in mask_paths]
-
-        if self._class_mappings:
-            mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
-                self.classes[1:]
-            )
-            for original_class, mapped_class in self._class_mappings.items():
-                mapped_index = self.class_to_idx[mapped_class] - 1
-                original_index = list(self._class_mappings.keys()).index(original_class)
-                mapped_binary_masks[mapped_index] = np.logical_or(
-                    mapped_binary_masks[mapped_index], binary_masks[original_index]
-                )
-            binary_masks = mapped_binary_masks
-
-        background_mask = np.zeros_like(binary_masks[0])
-        return np.argmax([background_mask] + binary_masks, axis=0)
-
-    def _export_semantic_label_masks(self) -> None:
-        """Exports the segmentation binary masks (one-hot) to semantic labels."""
-        mask_classes_file = os.path.join(f"{self._get_optimized_masks_root()}/classes.txt")
-        if os.path.isfile(mask_classes_file):
-            with open(mask_classes_file, "r") as file:
-                if file.read() != str(self.classes):
-                    raise ValueError(
-                        "Optimized masks hash doesn't match the current classes or mappings."
-                    )
-            return
-
-        total_samples = len(self._samples_dirs)
-        semantic_labels = [
-            (index, self._get_optimized_masks_file(index)) for index in range(total_samples)
-        ]
-        to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
-
-        def _process_mask(sample_index: Any, filename: str) -> None:
-            semantic_labels = self._load_masks_as_semantic_label(sample_index)
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            io.save_array_as_nifti(semantic_labels, filename)
-
-        multiprocessing.run_with_threads(
-            _process_mask,
-            list(to_export),
-            num_workers=self._num_workers,
-            progress_desc=">> Exporting optimized semantic mask",
-            return_results=False,
-        )
-
-        os.makedirs(os.path.dirname(mask_classes_file), exist_ok=True)
-        with open(mask_classes_file, "w") as file:
-            file.write(str(self.classes))
-
-    def _fix_orientation(self, array: npt.NDArray):
-        """Fixes orientation such that table is at the bottom & liver on the left."""
-        array = np.rot90(array)
-        array = np.flip(array, axis=1)
-        return array
-
-    def _get_image_path(self, sample_index: int) -> str:
-        """Returns the corresponding image path."""
-        sample_dir = self._samples_dirs[sample_index]
-        return os.path.join(self._root, sample_dir, f"ct.{self._file_suffix}")
-
-    def _get_masks_dir(self, sample_index: int) -> str:
-        """Returns the directory of the corresponding masks."""
-        sample_dir = self._samples_dirs[sample_index]
-        return os.path.join(self._root, sample_dir, "segmentations")
-
-    def _get_optimized_masks_root(self) -> str:
-        """Returns the directory of the optimized masks."""
-        return os.path.join(self._root, f"processed/masks/{self._classes_hash}")
-
-    def _get_optimized_masks_file(self, sample_index: int) -> str:
-        """Returns the semantic label filename."""
-        return os.path.join(
-            f"{self._get_optimized_masks_root()}/{self._samples_dirs[sample_index]}/masks.nii"
-        )
-
-    def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
-        """Returns the total amount of slices of a sample."""
-        image_path = self._get_image_path(sample_index)
-        image_shape = io.fetch_nifti_shape(image_path)
-        return image_shape[-1]
-
-    def _fetch_samples_dirs(self) -> List[str]:
-        """Returns the name of all the samples of all the splits of the dataset."""
-        sample_filenames = [
-            filename
-            for filename in os.listdir(self._root)
-            if os.path.isdir(os.path.join(self._root, filename)) and re.match(r"^s\d{4}$", filename)
-        ]
-        return sorted(sample_filenames)
-
-    def _get_split_indices(self) -> List[int]:
-        """Returns the samples indices that corresponding the dataset split and version."""
-        metadata_file = os.path.join(self._root, "meta.csv")
-        metadata = io.read_csv(metadata_file, delimiter=";", encoding="utf-8-sig")
-
-        match self._split:
-            case "train":
-                image_ids = [item["image_id"] for item in metadata if item["split"] == "train"]
-            case "val":
-                image_ids = [item["image_id"] for item in metadata if item["split"] == "val"]
-            case "test":
-                image_ids = [item["image_id"] for item in metadata if item["split"] == "test"]
-            case _:
-                image_ids = self._samples_dirs
-
-        return sorted(map(self._samples_dirs.index, image_ids))
-
-    def _create_indices(self) -> List[Tuple[int, int]]:
-        """Builds the dataset indices for the specified split.
-
-        Returns:
-            A list of tuples, where the first value indicates the
-            sample index which the second its corresponding slice
-            index.
-        """
-        indices = [
-            (sample_idx, slide_idx)
-            for sample_idx in self._get_split_indices()
-            for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx))
-            if slide_idx % (self._sample_every_n_slices or 1) == 0
-        ]
-        return indices
-
-    def _download_dataset(self) -> None:
-        """Downloads the dataset."""
-        dataset_resources = {
-            "small": self._resources_small,
-            "full": self._resources_full,
-        }
-        resources = dataset_resources.get(self._version or "")
-        if resources is None:
-            raise ValueError(
-                f"Can't download data version '{self._version}'. Use 'small' or 'full'."
-            )
-
-        self._print_license()
-        for resource in resources:
-            if os.path.isdir(self._root):
-                continue
-
-            utils.download_and_extract_archive(
-                resource.url,
-                download_root=self._root,
-                filename=resource.filename,
-                remove_finished=True,
-            )
-
-    def _decompress_files(self) -> None:
-        compressed_paths = Path(self._root).rglob("*/ct.nii.gz")
-        multiprocessing.run_with_threads(
-            core_io.gunzip_file,
-            [(str(path),) for path in compressed_paths],
-            num_workers=self._num_workers,
-            progress_desc=">> Decompressing .gz files",
-            return_results=False,
-        )
-
-    def _print_license(self) -> None:
-        """Prints the dataset license."""
-        print(f"Dataset license: {self._license}")
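
With both files gone, `TotalSegmentator2D` is no longer part of the public dataset API in 0.3.3. A hedged compatibility check for downstream code (not eva code):

    try:
        from eva.vision.data.datasets import TotalSegmentator2D  # noqa: F401  # removed in 0.3.3
    except ImportError:
        print("TotalSegmentator2D is unavailable; pin kaiko-eva==0.3.1 if you still rely on it.")
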
{kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/WHEEL
File without changes

{kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/entry_points.txt
File without changes

{kaiko_eva-0.3.1.dist-info → kaiko_eva-0.3.3.dist-info}/licenses/LICENSE
File without changes