kaiko-eva 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/core/data/dataloaders/dataloader.py +5 -2
- eva/core/data/datamodules/datamodule.py +42 -5
- eva/core/data/datamodules/schemas.py +18 -1
- eva/core/data/datasets/__init__.py +4 -1
- eva/core/data/datasets/base.py +23 -0
- eva/core/data/datasets/typings.py +18 -0
- eva/core/data/samplers/__init__.py +4 -2
- eva/core/data/samplers/classification/__init__.py +5 -0
- eva/core/data/samplers/classification/balanced.py +96 -0
- eva/core/data/samplers/random.py +39 -0
- eva/core/data/samplers/sampler.py +27 -0
- eva/core/metrics/structs/module.py +30 -9
- eva/core/models/__init__.py +8 -1
- eva/core/models/modules/head.py +19 -1
- eva/core/models/modules/utils/__init__.py +2 -1
- eva/core/models/modules/utils/checkpoint.py +21 -0
- eva/core/models/wrappers/__init__.py +3 -1
- eva/core/models/wrappers/from_torchhub.py +93 -0
- eva/core/trainers/functional.py +4 -2
- eva/core/trainers/trainer.py +8 -4
- eva/vision/data/datasets/segmentation/_total_segmentator.py +91 -0
- eva/vision/data/datasets/segmentation/consep.py +4 -1
- eva/vision/data/datasets/segmentation/lits.py +3 -3
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +92 -37
- eva/vision/data/datasets/vision.py +1 -18
- eva/vision/losses/dice.py +0 -3
- eva/vision/metrics/__init__.py +5 -1
- eva/vision/metrics/defaults/segmentation/multiclass.py +30 -6
- eva/vision/metrics/segmentation/__init__.py +4 -0
- eva/vision/metrics/segmentation/_utils.py +1 -2
- eva/vision/metrics/segmentation/dice.py +69 -0
- eva/vision/metrics/segmentation/generalized_dice.py +2 -4
- eva/vision/metrics/segmentation/mean_iou.py +4 -8
- eva/vision/metrics/segmentation/monai_dice.py +57 -0
- eva/vision/metrics/wrappers/__init__.py +5 -0
- eva/vision/metrics/wrappers/monai.py +32 -0
- eva/vision/models/modules/semantic_segmentation.py +19 -1
- eva/vision/models/networks/backbones/__init__.py +2 -2
- eva/vision/models/networks/backbones/torchhub/__init__.py +5 -0
- eva/vision/models/networks/backbones/torchhub/backbones.py +61 -0
- eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
- eva/vision/models/wrappers/__init__.py +1 -1
- {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/METADATA +3 -2
- {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/RECORD +47 -34
- {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/WHEEL +0 -0
- {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Model wrapper for torch.hub models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.core.models import wrappers
|
|
10
|
+
from eva.core.models.wrappers import _utils
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TorchHubModel(wrappers.BaseModel):
|
|
14
|
+
"""Model wrapper for `torch.hub` models."""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
model_name: str,
|
|
19
|
+
repo_or_dir: str,
|
|
20
|
+
pretrained: bool = True,
|
|
21
|
+
checkpoint_path: str = "",
|
|
22
|
+
out_indices: int | Tuple[int, ...] | None = None,
|
|
23
|
+
norm: bool = False,
|
|
24
|
+
trust_repo: bool = True,
|
|
25
|
+
model_kwargs: Dict[str, Any] | None = None,
|
|
26
|
+
tensor_transforms: Callable | None = None,
|
|
27
|
+
) -> None:
|
|
28
|
+
"""Initializes the encoder.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
model_name: Name of model to instantiate.
|
|
32
|
+
repo_or_dir: The torch.hub repository or local directory to load the model from.
|
|
33
|
+
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
|
|
34
|
+
checkpoint_path: Path of checkpoint to load.
|
|
35
|
+
out_indices: Returns last n blocks if `int`, all if `None`, select
|
|
36
|
+
matching indices if sequence.
|
|
37
|
+
norm: Wether to apply norm layer to all intermediate features. Only
|
|
38
|
+
used when `out_indices` is not `None`.
|
|
39
|
+
trust_repo: If set to `False`, a prompt will ask the user whether the
|
|
40
|
+
repo should be trusted.
|
|
41
|
+
model_kwargs: Extra model arguments.
|
|
42
|
+
tensor_transforms: The transforms to apply to the output tensor
|
|
43
|
+
produced by the model.
|
|
44
|
+
"""
|
|
45
|
+
super().__init__(tensor_transforms=tensor_transforms)
|
|
46
|
+
|
|
47
|
+
self._model_name = model_name
|
|
48
|
+
self._repo_or_dir = repo_or_dir
|
|
49
|
+
self._pretrained = pretrained
|
|
50
|
+
self._checkpoint_path = checkpoint_path
|
|
51
|
+
self._out_indices = out_indices
|
|
52
|
+
self._norm = norm
|
|
53
|
+
self._trust_repo = trust_repo
|
|
54
|
+
self._model_kwargs = model_kwargs or {}
|
|
55
|
+
|
|
56
|
+
self.load_model()
|
|
57
|
+
|
|
58
|
+
@override
|
|
59
|
+
def load_model(self) -> None:
|
|
60
|
+
"""Builds and loads the torch.hub model."""
|
|
61
|
+
self._model: nn.Module = torch.hub.load(
|
|
62
|
+
repo_or_dir=self._repo_or_dir,
|
|
63
|
+
model=self._model_name,
|
|
64
|
+
trust_repo=self._trust_repo,
|
|
65
|
+
pretrained=self._pretrained,
|
|
66
|
+
**self._model_kwargs,
|
|
67
|
+
) # type: ignore
|
|
68
|
+
|
|
69
|
+
if self._checkpoint_path:
|
|
70
|
+
_utils.load_model_weights(self._model, self._checkpoint_path)
|
|
71
|
+
|
|
72
|
+
TorchHubModel.__name__ = self._model_name
|
|
73
|
+
|
|
74
|
+
@override
|
|
75
|
+
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
|
|
76
|
+
if self._out_indices is not None:
|
|
77
|
+
if not hasattr(self._model, "get_intermediate_layers"):
|
|
78
|
+
raise ValueError(
|
|
79
|
+
"Only models with `get_intermediate_layers` are supported "
|
|
80
|
+
"when using `out_indices`."
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
return list(
|
|
84
|
+
self._model.get_intermediate_layers(
|
|
85
|
+
tensor,
|
|
86
|
+
self._out_indices,
|
|
87
|
+
reshape=True,
|
|
88
|
+
return_class_token=False,
|
|
89
|
+
norm=self._norm,
|
|
90
|
+
)
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
return self._model(tensor)
|
eva/core/trainers/functional.py
CHANGED
|
@@ -96,11 +96,13 @@ def fit_and_validate(
|
|
|
96
96
|
A tuple of with the validation and the test metrics (if exists).
|
|
97
97
|
"""
|
|
98
98
|
trainer.fit(model, datamodule=datamodule)
|
|
99
|
-
validation_scores = trainer.validate(
|
|
99
|
+
validation_scores = trainer.validate(
|
|
100
|
+
datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type
|
|
101
|
+
)
|
|
100
102
|
test_scores = (
|
|
101
103
|
None
|
|
102
104
|
if datamodule.datasets.test is None
|
|
103
|
-
else trainer.test(datamodule=datamodule, verbose=verbose)
|
|
105
|
+
else trainer.test(datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type)
|
|
104
106
|
)
|
|
105
107
|
return validation_scores, test_scores
|
|
106
108
|
|
eva/core/trainers/trainer.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Core trainer module."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
from typing import Any
|
|
4
|
+
from typing import Any, Literal
|
|
5
5
|
|
|
6
6
|
import loguru
|
|
7
7
|
from lightning.pytorch import loggers as pl_loggers
|
|
@@ -28,6 +28,7 @@ class Trainer(pl_trainer.Trainer):
|
|
|
28
28
|
*args: Any,
|
|
29
29
|
default_root_dir: str = "logs",
|
|
30
30
|
n_runs: int = 1,
|
|
31
|
+
checkpoint_type: Literal["best", "last"] = "best",
|
|
31
32
|
**kwargs: Any,
|
|
32
33
|
) -> None:
|
|
33
34
|
"""Initializes the trainer.
|
|
@@ -40,11 +41,14 @@ class Trainer(pl_trainer.Trainer):
|
|
|
40
41
|
Unlike in ::class::`lightning.pytorch.Trainer`, this path would be the
|
|
41
42
|
prioritized destination point.
|
|
42
43
|
n_runs: The amount of runs (fit and evaluate) to perform in an evaluation session.
|
|
44
|
+
checkpoint_type: Wether to load the "best" or "last" checkpoint saved by the checkpoint
|
|
45
|
+
callback for evaluations on validation & test sets.
|
|
43
46
|
kwargs: Kew-word arguments of ::class::`lightning.pytorch.Trainer`.
|
|
44
47
|
"""
|
|
45
48
|
super().__init__(*args, default_root_dir=default_root_dir, **kwargs)
|
|
46
49
|
|
|
47
|
-
self.
|
|
50
|
+
self.checkpoint_type = checkpoint_type
|
|
51
|
+
self.n_runs = n_runs
|
|
48
52
|
|
|
49
53
|
self._session_id: str = _logging.generate_session_id()
|
|
50
54
|
self._log_dir: str = self.default_log_dir
|
|
@@ -106,6 +110,6 @@ class Trainer(pl_trainer.Trainer):
|
|
|
106
110
|
base_trainer=self,
|
|
107
111
|
base_model=model,
|
|
108
112
|
datamodule=datamodule,
|
|
109
|
-
n_runs=self.
|
|
110
|
-
verbose=self.
|
|
113
|
+
n_runs=self.n_runs,
|
|
114
|
+
verbose=self.n_runs > 1,
|
|
111
115
|
)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Utils for TotalSegmentator dataset classes."""
|
|
2
|
+
|
|
3
|
+
from typing import Dict
|
|
4
|
+
|
|
5
|
+
reduced_class_mappings: Dict[str, str] = {
|
|
6
|
+
# Abdominal Organs
|
|
7
|
+
"spleen": "spleen",
|
|
8
|
+
"kidney_right": "kidney",
|
|
9
|
+
"kidney_left": "kidney",
|
|
10
|
+
"gallbladder": "gallbladder",
|
|
11
|
+
"liver": "liver",
|
|
12
|
+
"stomach": "stomach",
|
|
13
|
+
"pancreas": "pancreas",
|
|
14
|
+
"small_bowel": "small_bowel",
|
|
15
|
+
"duodenum": "duodenum",
|
|
16
|
+
"colon": "colon",
|
|
17
|
+
# Endocrine System
|
|
18
|
+
"adrenal_gland_right": "adrenal_gland",
|
|
19
|
+
"adrenal_gland_left": "adrenal_gland",
|
|
20
|
+
"thyroid_gland": "thyroid_gland",
|
|
21
|
+
# Respiratory System
|
|
22
|
+
"lung_upper_lobe_left": "lungs",
|
|
23
|
+
"lung_lower_lobe_left": "lungs",
|
|
24
|
+
"lung_upper_lobe_right": "lungs",
|
|
25
|
+
"lung_middle_lobe_right": "lungs",
|
|
26
|
+
"lung_lower_lobe_right": "lungs",
|
|
27
|
+
"trachea": "trachea",
|
|
28
|
+
"esophagus": "esophagus",
|
|
29
|
+
# Urogenital System
|
|
30
|
+
"urinary_bladder": "urogenital_system",
|
|
31
|
+
"prostate": "urogenital_system",
|
|
32
|
+
"kidney_cyst_left": "kidney_cyst",
|
|
33
|
+
"kidney_cyst_right": "kidney_cyst",
|
|
34
|
+
# Vertebral Column
|
|
35
|
+
**{f"vertebrae_{v}": "vertebrae" for v in ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]},
|
|
36
|
+
**{f"vertebrae_{v}": "vertebrae" for v in [f"T{i}" for i in range(1, 13)]},
|
|
37
|
+
**{f"vertebrae_{v}": "vertebrae" for v in [f"L{i}" for i in range(1, 6)]},
|
|
38
|
+
"vertebrae_S1": "vertebrae",
|
|
39
|
+
"sacrum": "sacral_spine",
|
|
40
|
+
# Cardiovascular System
|
|
41
|
+
"heart": "heart",
|
|
42
|
+
"aorta": "aorta",
|
|
43
|
+
"pulmonary_vein": "veins",
|
|
44
|
+
"brachiocephalic_trunk": "arteries",
|
|
45
|
+
"subclavian_artery_right": "arteries",
|
|
46
|
+
"subclavian_artery_left": "arteries",
|
|
47
|
+
"common_carotid_artery_right": "arteries",
|
|
48
|
+
"common_carotid_artery_left": "arteries",
|
|
49
|
+
"brachiocephalic_vein_left": "veins",
|
|
50
|
+
"brachiocephalic_vein_right": "veins",
|
|
51
|
+
"atrial_appendage_left": "atrial_appendage",
|
|
52
|
+
"superior_vena_cava": "veins",
|
|
53
|
+
"inferior_vena_cava": "veins",
|
|
54
|
+
"portal_vein_and_splenic_vein": "veins",
|
|
55
|
+
"iliac_artery_left": "arteries",
|
|
56
|
+
"iliac_artery_right": "arteries",
|
|
57
|
+
"iliac_vena_left": "veins",
|
|
58
|
+
"iliac_vena_right": "veins",
|
|
59
|
+
# Upper Extremity Bones
|
|
60
|
+
"humerus_left": "humerus",
|
|
61
|
+
"humerus_right": "humerus",
|
|
62
|
+
"scapula_left": "scapula",
|
|
63
|
+
"scapula_right": "scapula",
|
|
64
|
+
"clavicula_left": "clavicula",
|
|
65
|
+
"clavicula_right": "clavicula",
|
|
66
|
+
# Lower Extremity Bones
|
|
67
|
+
"femur_left": "femur",
|
|
68
|
+
"femur_right": "femur",
|
|
69
|
+
"hip_left": "hip",
|
|
70
|
+
"hip_right": "hip",
|
|
71
|
+
# Muscles
|
|
72
|
+
"gluteus_maximus_left": "gluteus",
|
|
73
|
+
"gluteus_maximus_right": "gluteus",
|
|
74
|
+
"gluteus_medius_left": "gluteus",
|
|
75
|
+
"gluteus_medius_right": "gluteus",
|
|
76
|
+
"gluteus_minimus_left": "gluteus",
|
|
77
|
+
"gluteus_minimus_right": "gluteus",
|
|
78
|
+
"autochthon_left": "autochthon",
|
|
79
|
+
"autochthon_right": "autochthon",
|
|
80
|
+
"iliopsoas_left": "iliopsoas",
|
|
81
|
+
"iliopsoas_right": "iliopsoas",
|
|
82
|
+
# Central Nervous System
|
|
83
|
+
"brain": "brain",
|
|
84
|
+
"spinal_cord": "spinal_cord",
|
|
85
|
+
# Skull and Thoracic Cage
|
|
86
|
+
"skull": "skull",
|
|
87
|
+
**{f"rib_left_{i}": "ribs" for i in range(1, 13)},
|
|
88
|
+
**{f"rib_right_{i}": "ribs" for i in range(1, 13)},
|
|
89
|
+
"costal_cartilages": "ribs",
|
|
90
|
+
"sternum": "sternum",
|
|
91
|
+
}
|
|
@@ -20,9 +20,12 @@ from eva.vision.utils import io
|
|
|
20
20
|
class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
|
|
21
21
|
"""Dataset class for CoNSeP semantic segmentation task.
|
|
22
22
|
|
|
23
|
-
|
|
23
|
+
As in [1], we combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
|
|
24
24
|
into the epithelial class and 5 (fibroblast), 6 (muscle) & 7 (endothelial) into
|
|
25
25
|
the spindle-shaped class.
|
|
26
|
+
|
|
27
|
+
[1] Graham, Simon, et al. "Hover-net: Simultaneous segmentation and classification of
|
|
28
|
+
nuclei in multi-tissue histology images." https://arxiv.org/abs/1802.04712
|
|
26
29
|
"""
|
|
27
30
|
|
|
28
31
|
_expected_dataset_lengths: Dict[str | None, int] = {
|
|
@@ -76,7 +76,7 @@ class LiTS(base.ImageSegmentation):
|
|
|
76
76
|
@property
|
|
77
77
|
@override
|
|
78
78
|
def classes(self) -> List[str]:
|
|
79
|
-
return ["liver", "tumor"]
|
|
79
|
+
return ["background", "liver", "tumor"]
|
|
80
80
|
|
|
81
81
|
@functools.cached_property
|
|
82
82
|
@override
|
|
@@ -105,8 +105,8 @@ class LiTS(base.ImageSegmentation):
|
|
|
105
105
|
_validators.check_dataset_integrity(
|
|
106
106
|
self,
|
|
107
107
|
length=self._expected_dataset_lengths.get(self._split, 0),
|
|
108
|
-
n_classes=
|
|
109
|
-
first_and_last_labels=("
|
|
108
|
+
n_classes=3,
|
|
109
|
+
first_and_last_labels=("background", "tumor"),
|
|
110
110
|
)
|
|
111
111
|
|
|
112
112
|
@override
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""TotalSegmentator 2D segmentation dataset class."""
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
|
+
import hashlib
|
|
4
5
|
import os
|
|
6
|
+
import re
|
|
5
7
|
from glob import glob
|
|
6
8
|
from pathlib import Path
|
|
7
9
|
from typing import Any, Callable, Dict, List, Literal, Tuple
|
|
@@ -16,7 +18,7 @@ from typing_extensions import override
|
|
|
16
18
|
from eva.core.utils import io as core_io
|
|
17
19
|
from eva.core.utils import multiprocessing
|
|
18
20
|
from eva.vision.data.datasets import _validators, structs
|
|
19
|
-
from eva.vision.data.datasets.segmentation import base
|
|
21
|
+
from eva.vision.data.datasets.segmentation import _total_segmentator, base
|
|
20
22
|
from eva.vision.utils import io
|
|
21
23
|
|
|
22
24
|
|
|
@@ -66,6 +68,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
66
68
|
version: Literal["small", "full"] | None = "full",
|
|
67
69
|
download: bool = False,
|
|
68
70
|
classes: List[str] | None = None,
|
|
71
|
+
class_mappings: Dict[str, str] | None = _total_segmentator.reduced_class_mappings,
|
|
69
72
|
optimize_mask_loading: bool = True,
|
|
70
73
|
decompress: bool = True,
|
|
71
74
|
num_workers: int = 10,
|
|
@@ -85,6 +88,8 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
85
88
|
exist yet on disk.
|
|
86
89
|
classes: Whether to configure the dataset with a subset of classes.
|
|
87
90
|
If `None`, it will use all of them.
|
|
91
|
+
class_mappings: A dictionary that maps the original class names to a
|
|
92
|
+
reduced set of classes. If `None`, it will use the original classes.
|
|
88
93
|
optimize_mask_loading: Whether to pre-process the segmentation masks
|
|
89
94
|
in order to optimize the loading time. In the `setup` method, it
|
|
90
95
|
will reformat the binary one-hot masks to a semantic mask and store
|
|
@@ -109,11 +114,10 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
109
114
|
self._optimize_mask_loading = optimize_mask_loading
|
|
110
115
|
self._decompress = decompress
|
|
111
116
|
self._num_workers = num_workers
|
|
117
|
+
self._class_mappings = class_mappings
|
|
112
118
|
|
|
113
|
-
if self.
|
|
114
|
-
raise ValueError(
|
|
115
|
-
"To use customize classes please set the optimize_mask_loading to `False`."
|
|
116
|
-
)
|
|
119
|
+
if self._classes and self._class_mappings:
|
|
120
|
+
raise ValueError("Both 'classes' and 'class_mappings' cannot be set at the same time.")
|
|
117
121
|
|
|
118
122
|
self._samples_dirs: List[str] = []
|
|
119
123
|
self._indices: List[Tuple[int, int]] = []
|
|
@@ -125,16 +129,21 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
125
129
|
"""Returns the filename from the full path."""
|
|
126
130
|
return os.path.basename(path).split(".")[0]
|
|
127
131
|
|
|
128
|
-
first_sample_labels = os.path.join(
|
|
129
|
-
self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
|
|
130
|
-
)
|
|
132
|
+
first_sample_labels = os.path.join(self._root, "s0011", "segmentations", "*.nii.gz")
|
|
131
133
|
all_classes = sorted(map(get_filename, glob(first_sample_labels)))
|
|
132
134
|
if self._classes:
|
|
133
135
|
is_subset = all(name in all_classes for name in self._classes)
|
|
134
136
|
if not is_subset:
|
|
135
|
-
raise ValueError("Provided class names are not subset of the
|
|
136
|
-
|
|
137
|
-
|
|
137
|
+
raise ValueError("Provided class names are not subset of the original ones.")
|
|
138
|
+
classes = sorted(self._classes)
|
|
139
|
+
elif self._class_mappings:
|
|
140
|
+
is_subset = all(name in all_classes for name in self._class_mappings.keys())
|
|
141
|
+
if not is_subset:
|
|
142
|
+
raise ValueError("Provided class names are not subset of the original ones.")
|
|
143
|
+
classes = sorted(set(self._class_mappings.values()))
|
|
144
|
+
else:
|
|
145
|
+
classes = all_classes
|
|
146
|
+
return ["background"] + classes
|
|
138
147
|
|
|
139
148
|
@property
|
|
140
149
|
@override
|
|
@@ -145,6 +154,10 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
145
154
|
def _file_suffix(self) -> str:
|
|
146
155
|
return "nii" if self._decompress else "nii.gz"
|
|
147
156
|
|
|
157
|
+
@functools.cached_property
|
|
158
|
+
def _classes_hash(self) -> str:
|
|
159
|
+
return hashlib.md5(str(self.classes).encode(), usedforsecurity=False).hexdigest()
|
|
160
|
+
|
|
148
161
|
@override
|
|
149
162
|
def filename(self, index: int) -> str:
|
|
150
163
|
sample_idx, _ = self._indices[index]
|
|
@@ -170,15 +183,22 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
170
183
|
if self._version is None or self._sample_every_n_slices is not None:
|
|
171
184
|
return
|
|
172
185
|
|
|
186
|
+
if self._classes:
|
|
187
|
+
last_label = self._classes[-1]
|
|
188
|
+
n_classes = len(self._classes)
|
|
189
|
+
elif self._class_mappings:
|
|
190
|
+
classes = sorted(set(self._class_mappings.values()))
|
|
191
|
+
last_label = classes[-1]
|
|
192
|
+
n_classes = len(classes)
|
|
193
|
+
else:
|
|
194
|
+
last_label = "vertebrae_T9"
|
|
195
|
+
n_classes = 117
|
|
196
|
+
|
|
173
197
|
_validators.check_dataset_integrity(
|
|
174
198
|
self,
|
|
175
199
|
length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
|
|
176
|
-
n_classes=
|
|
177
|
-
first_and_last_labels=(
|
|
178
|
-
(self._classes[0], self._classes[-1])
|
|
179
|
-
if self._classes
|
|
180
|
-
else ("adrenal_gland_left", "vertebrae_T9")
|
|
181
|
-
),
|
|
200
|
+
n_classes=n_classes + 1,
|
|
201
|
+
first_and_last_labels=("background", last_label),
|
|
182
202
|
)
|
|
183
203
|
|
|
184
204
|
@override
|
|
@@ -190,32 +210,31 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
190
210
|
sample_index, slice_index = self._indices[index]
|
|
191
211
|
image_path = self._get_image_path(sample_index)
|
|
192
212
|
image_array = io.read_nifti(image_path, slice_index)
|
|
193
|
-
|
|
194
|
-
return tv_tensors.Image(
|
|
213
|
+
image_array = self._fix_orientation(image_array)
|
|
214
|
+
return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))
|
|
195
215
|
|
|
196
216
|
@override
|
|
197
217
|
def load_mask(self, index: int) -> tv_tensors.Mask:
|
|
198
218
|
if self._optimize_mask_loading:
|
|
199
|
-
|
|
200
|
-
|
|
219
|
+
mask = self._load_semantic_label_mask(index)
|
|
220
|
+
else:
|
|
221
|
+
mask = self._load_mask(index)
|
|
222
|
+
mask = self._fix_orientation(mask)
|
|
223
|
+
return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64) # type: ignore
|
|
201
224
|
|
|
202
225
|
@override
|
|
203
226
|
def load_metadata(self, index: int) -> Dict[str, Any]:
|
|
204
227
|
_, slice_index = self._indices[index]
|
|
205
228
|
return {"slice_index": slice_index}
|
|
206
229
|
|
|
207
|
-
def _load_mask(self, index: int) ->
|
|
230
|
+
def _load_mask(self, index: int) -> npt.NDArray[Any]:
|
|
208
231
|
sample_index, slice_index = self._indices[index]
|
|
209
|
-
|
|
210
|
-
return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
|
|
232
|
+
return self._load_masks_as_semantic_label(sample_index, slice_index)
|
|
211
233
|
|
|
212
|
-
def _load_semantic_label_mask(self, index: int) ->
|
|
234
|
+
def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
|
|
213
235
|
"""Loads the segmentation mask from a semantic label NifTi file."""
|
|
214
236
|
sample_index, slice_index = self._indices[index]
|
|
215
|
-
|
|
216
|
-
filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
|
|
217
|
-
semantic_labels = io.read_nifti(filename, slice_index)
|
|
218
|
-
return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
|
|
237
|
+
return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
|
|
219
238
|
|
|
220
239
|
def _load_masks_as_semantic_label(
|
|
221
240
|
self, sample_index: int, slice_index: int | None = None
|
|
@@ -227,18 +246,39 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
227
246
|
slice_index: Whether to return only a specific slice.
|
|
228
247
|
"""
|
|
229
248
|
masks_dir = self._get_masks_dir(sample_index)
|
|
230
|
-
|
|
249
|
+
classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
|
|
250
|
+
mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
|
|
231
251
|
binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
|
|
252
|
+
|
|
253
|
+
if self._class_mappings:
|
|
254
|
+
mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
|
|
255
|
+
self.classes[1:]
|
|
256
|
+
)
|
|
257
|
+
for original_class, mapped_class in self._class_mappings.items():
|
|
258
|
+
mapped_index = self.class_to_idx[mapped_class] - 1
|
|
259
|
+
original_index = list(self._class_mappings.keys()).index(original_class)
|
|
260
|
+
mapped_binary_masks[mapped_index] = np.logical_or(
|
|
261
|
+
mapped_binary_masks[mapped_index], binary_masks[original_index]
|
|
262
|
+
)
|
|
263
|
+
binary_masks = mapped_binary_masks
|
|
264
|
+
|
|
232
265
|
background_mask = np.zeros_like(binary_masks[0])
|
|
233
266
|
return np.argmax([background_mask] + binary_masks, axis=0)
|
|
234
267
|
|
|
235
268
|
def _export_semantic_label_masks(self) -> None:
|
|
236
269
|
"""Exports the segmentation binary masks (one-hot) to semantic labels."""
|
|
270
|
+
mask_classes_file = os.path.join(f"{self._get_optimized_masks_root()}/classes.txt")
|
|
271
|
+
if os.path.isfile(mask_classes_file):
|
|
272
|
+
with open(mask_classes_file, "r") as file:
|
|
273
|
+
if file.read() != str(self.classes):
|
|
274
|
+
raise ValueError(
|
|
275
|
+
"Optimized masks hash doesn't match the current classes or mappings."
|
|
276
|
+
)
|
|
277
|
+
return
|
|
278
|
+
|
|
237
279
|
total_samples = len(self._samples_dirs)
|
|
238
|
-
masks_dirs = map(self._get_masks_dir, range(total_samples))
|
|
239
280
|
semantic_labels = [
|
|
240
|
-
(index,
|
|
241
|
-
for index, directory in enumerate(masks_dirs)
|
|
281
|
+
(index, self._get_optimized_masks_file(index)) for index in range(total_samples)
|
|
242
282
|
]
|
|
243
283
|
to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
|
|
244
284
|
|
|
@@ -255,6 +295,16 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
255
295
|
return_results=False,
|
|
256
296
|
)
|
|
257
297
|
|
|
298
|
+
os.makedirs(os.path.dirname(mask_classes_file), exist_ok=True)
|
|
299
|
+
with open(mask_classes_file, "w") as file:
|
|
300
|
+
file.write(str(self.classes))
|
|
301
|
+
|
|
302
|
+
def _fix_orientation(self, array: npt.NDArray):
|
|
303
|
+
"""Fixes orientation such that table is at the bottom & liver on the left."""
|
|
304
|
+
array = np.rot90(array)
|
|
305
|
+
array = np.flip(array, axis=1)
|
|
306
|
+
return array
|
|
307
|
+
|
|
258
308
|
def _get_image_path(self, sample_index: int) -> str:
|
|
259
309
|
"""Returns the corresponding image path."""
|
|
260
310
|
sample_dir = self._samples_dirs[sample_index]
|
|
@@ -265,10 +315,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
265
315
|
sample_dir = self._samples_dirs[sample_index]
|
|
266
316
|
return os.path.join(self._root, sample_dir, "segmentations")
|
|
267
317
|
|
|
268
|
-
def
|
|
318
|
+
def _get_optimized_masks_root(self) -> str:
|
|
319
|
+
"""Returns the directory of the optimized masks."""
|
|
320
|
+
return os.path.join(self._root, f"processed/masks/{self._classes_hash}")
|
|
321
|
+
|
|
322
|
+
def _get_optimized_masks_file(self, sample_index: int) -> str:
|
|
269
323
|
"""Returns the semantic label filename."""
|
|
270
|
-
|
|
271
|
-
|
|
324
|
+
return os.path.join(
|
|
325
|
+
f"{self._get_optimized_masks_root()}/{self._samples_dirs[sample_index]}/masks.nii"
|
|
326
|
+
)
|
|
272
327
|
|
|
273
328
|
def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
|
|
274
329
|
"""Returns the total amount of slices of a sample."""
|
|
@@ -281,7 +336,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
|
|
|
281
336
|
sample_filenames = [
|
|
282
337
|
filename
|
|
283
338
|
for filename in os.listdir(self._root)
|
|
284
|
-
if os.path.isdir(os.path.join(self._root, filename))
|
|
339
|
+
if os.path.isdir(os.path.join(self._root, filename)) and re.match(r"^s\d{4}$", filename)
|
|
285
340
|
]
|
|
286
341
|
return sorted(sample_filenames)
|
|
287
342
|
|
|
@@ -9,7 +9,7 @@ DataSample = TypeVar("DataSample")
|
|
|
9
9
|
"""The data sample type."""
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
class VisionDataset(base.
|
|
12
|
+
class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
|
|
13
13
|
"""Base dataset class for vision tasks."""
|
|
14
14
|
|
|
15
15
|
@abc.abstractmethod
|
|
@@ -24,20 +24,3 @@ class VisionDataset(base.Dataset, abc.ABC, Generic[DataSample]):
|
|
|
24
24
|
Returns:
|
|
25
25
|
The filename of the `index`'th data sample.
|
|
26
26
|
"""
|
|
27
|
-
|
|
28
|
-
@abc.abstractmethod
|
|
29
|
-
def __getitem__(self, index: int) -> DataSample:
|
|
30
|
-
"""Returns the `index`'th data sample.
|
|
31
|
-
|
|
32
|
-
Args:
|
|
33
|
-
index: The index of the data-sample to select.
|
|
34
|
-
|
|
35
|
-
Returns:
|
|
36
|
-
A data sample and its target.
|
|
37
|
-
"""
|
|
38
|
-
raise NotImplementedError
|
|
39
|
-
|
|
40
|
-
@abc.abstractmethod
|
|
41
|
-
def __len__(self) -> int:
|
|
42
|
-
"""Returns the total length of the data."""
|
|
43
|
-
raise NotImplementedError
|
eva/vision/losses/dice.py
CHANGED
|
@@ -45,9 +45,6 @@ class DiceLoss(losses.DiceLoss): # type: ignore
|
|
|
45
45
|
inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
|
|
46
46
|
targets = _to_one_hot(targets, num_classes=inputs.shape[1])
|
|
47
47
|
|
|
48
|
-
if targets.ndim == 3:
|
|
49
|
-
targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
|
|
50
|
-
|
|
51
48
|
return super().forward(inputs, targets)
|
|
52
49
|
|
|
53
50
|
|
eva/vision/metrics/__init__.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
"""Default metric collections API."""
|
|
2
2
|
|
|
3
3
|
from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
|
|
4
|
+
from eva.vision.metrics.segmentation.dice import DiceScore
|
|
4
5
|
from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
|
|
5
6
|
from eva.vision.metrics.segmentation.mean_iou import MeanIoU
|
|
7
|
+
from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore
|
|
6
8
|
|
|
7
9
|
__all__ = [
|
|
8
|
-
"
|
|
10
|
+
"DiceScore",
|
|
9
11
|
"GeneralizedDiceScore",
|
|
10
12
|
"MeanIoU",
|
|
13
|
+
"MonaiDiceScore",
|
|
14
|
+
"MulticlassSegmentationMetrics",
|
|
11
15
|
]
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Default metric collection for multiclass semantic segmentation tasks."""
|
|
2
2
|
|
|
3
3
|
from eva.core.metrics import structs
|
|
4
|
-
from eva.vision.metrics
|
|
4
|
+
from eva.vision.metrics import segmentation
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class MulticlassSegmentationMetrics(structs.MetricCollection):
|
|
@@ -26,19 +26,43 @@ class MulticlassSegmentationMetrics(structs.MetricCollection):
|
|
|
26
26
|
postfix: A string to add after the keys in the output dictionary.
|
|
27
27
|
"""
|
|
28
28
|
super().__init__(
|
|
29
|
-
metrics=
|
|
30
|
-
|
|
29
|
+
metrics={
|
|
30
|
+
"MonaiDiceScore": segmentation.MonaiDiceScore(
|
|
31
31
|
num_classes=num_classes,
|
|
32
32
|
include_background=include_background,
|
|
33
|
-
weight_type="linear",
|
|
34
33
|
ignore_index=ignore_index,
|
|
34
|
+
ignore_empty=True,
|
|
35
35
|
),
|
|
36
|
-
|
|
36
|
+
"MonaiDiceScore (ignore_empty=False)": segmentation.MonaiDiceScore(
|
|
37
37
|
num_classes=num_classes,
|
|
38
38
|
include_background=include_background,
|
|
39
39
|
ignore_index=ignore_index,
|
|
40
|
+
ignore_empty=False,
|
|
40
41
|
),
|
|
41
|
-
|
|
42
|
+
"DiceScore (micro)": segmentation.DiceScore(
|
|
43
|
+
num_classes=num_classes,
|
|
44
|
+
include_background=include_background,
|
|
45
|
+
average="micro",
|
|
46
|
+
ignore_index=ignore_index,
|
|
47
|
+
),
|
|
48
|
+
"DiceScore (macro)": segmentation.DiceScore(
|
|
49
|
+
num_classes=num_classes,
|
|
50
|
+
include_background=include_background,
|
|
51
|
+
average="macro",
|
|
52
|
+
ignore_index=ignore_index,
|
|
53
|
+
),
|
|
54
|
+
"DiceScore (weighted)": segmentation.DiceScore(
|
|
55
|
+
num_classes=num_classes,
|
|
56
|
+
include_background=include_background,
|
|
57
|
+
average="weighted",
|
|
58
|
+
ignore_index=ignore_index,
|
|
59
|
+
),
|
|
60
|
+
"MeanIoU": segmentation.MeanIoU(
|
|
61
|
+
num_classes=num_classes,
|
|
62
|
+
include_background=include_background,
|
|
63
|
+
ignore_index=ignore_index,
|
|
64
|
+
),
|
|
65
|
+
},
|
|
42
66
|
prefix=prefix,
|
|
43
67
|
postfix=postfix,
|
|
44
68
|
)
|
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
"""Segmentation metrics API."""
|
|
2
2
|
|
|
3
|
+
from eva.vision.metrics.segmentation.dice import DiceScore
|
|
3
4
|
from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
|
|
4
5
|
from eva.vision.metrics.segmentation.mean_iou import MeanIoU
|
|
6
|
+
from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore
|
|
5
7
|
|
|
6
8
|
__all__ = [
|
|
9
|
+
"DiceScore",
|
|
10
|
+
"MonaiDiceScore",
|
|
7
11
|
"GeneralizedDiceScore",
|
|
8
12
|
"MeanIoU",
|
|
9
13
|
]
|