kaiko-eva 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (47)
  1. eva/core/data/dataloaders/dataloader.py +5 -2
  2. eva/core/data/datamodules/datamodule.py +42 -5
  3. eva/core/data/datamodules/schemas.py +18 -1
  4. eva/core/data/datasets/__init__.py +4 -1
  5. eva/core/data/datasets/base.py +23 -0
  6. eva/core/data/datasets/typings.py +18 -0
  7. eva/core/data/samplers/__init__.py +4 -2
  8. eva/core/data/samplers/classification/__init__.py +5 -0
  9. eva/core/data/samplers/classification/balanced.py +96 -0
  10. eva/core/data/samplers/random.py +39 -0
  11. eva/core/data/samplers/sampler.py +27 -0
  12. eva/core/metrics/structs/module.py +30 -9
  13. eva/core/models/__init__.py +8 -1
  14. eva/core/models/modules/head.py +19 -1
  15. eva/core/models/modules/utils/__init__.py +2 -1
  16. eva/core/models/modules/utils/checkpoint.py +21 -0
  17. eva/core/models/wrappers/__init__.py +3 -1
  18. eva/core/models/wrappers/from_torchhub.py +93 -0
  19. eva/core/trainers/functional.py +4 -2
  20. eva/core/trainers/trainer.py +8 -4
  21. eva/vision/data/datasets/segmentation/_total_segmentator.py +91 -0
  22. eva/vision/data/datasets/segmentation/consep.py +4 -1
  23. eva/vision/data/datasets/segmentation/lits.py +3 -3
  24. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +92 -37
  25. eva/vision/data/datasets/vision.py +1 -18
  26. eva/vision/losses/dice.py +0 -3
  27. eva/vision/metrics/__init__.py +5 -1
  28. eva/vision/metrics/defaults/segmentation/multiclass.py +30 -6
  29. eva/vision/metrics/segmentation/__init__.py +4 -0
  30. eva/vision/metrics/segmentation/_utils.py +1 -2
  31. eva/vision/metrics/segmentation/dice.py +69 -0
  32. eva/vision/metrics/segmentation/generalized_dice.py +2 -4
  33. eva/vision/metrics/segmentation/mean_iou.py +4 -8
  34. eva/vision/metrics/segmentation/monai_dice.py +57 -0
  35. eva/vision/metrics/wrappers/__init__.py +5 -0
  36. eva/vision/metrics/wrappers/monai.py +32 -0
  37. eva/vision/models/modules/semantic_segmentation.py +19 -1
  38. eva/vision/models/networks/backbones/__init__.py +2 -2
  39. eva/vision/models/networks/backbones/torchhub/__init__.py +5 -0
  40. eva/vision/models/networks/backbones/torchhub/backbones.py +61 -0
  41. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  42. eva/vision/models/wrappers/__init__.py +1 -1
  43. {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/METADATA +3 -2
  44. {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/RECORD +47 -34
  45. {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/WHEEL +0 -0
  46. {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/entry_points.txt +0 -0
  47. {kaiko_eva-0.1.6.dist-info → kaiko_eva-0.1.8.dist-info}/licenses/LICENSE +0 -0
eva/core/models/wrappers/from_torchhub.py ADDED
@@ -0,0 +1,93 @@
+ """Model wrapper for torch.hub models."""
+
+ from typing import Any, Callable, Dict, List, Tuple
+
+ import torch
+ import torch.nn as nn
+ from typing_extensions import override
+
+ from eva.core.models import wrappers
+ from eva.core.models.wrappers import _utils
+
+
+ class TorchHubModel(wrappers.BaseModel):
+     """Model wrapper for `torch.hub` models."""
+
+     def __init__(
+         self,
+         model_name: str,
+         repo_or_dir: str,
+         pretrained: bool = True,
+         checkpoint_path: str = "",
+         out_indices: int | Tuple[int, ...] | None = None,
+         norm: bool = False,
+         trust_repo: bool = True,
+         model_kwargs: Dict[str, Any] | None = None,
+         tensor_transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the encoder.
+
+         Args:
+             model_name: Name of model to instantiate.
+             repo_or_dir: The torch.hub repository or local directory to load the model from.
+             pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+             checkpoint_path: Path of checkpoint to load.
+             out_indices: Returns last n blocks if `int`, all if `None`, select
+                 matching indices if sequence.
+             norm: Whether to apply norm layer to all intermediate features. Only
+                 used when `out_indices` is not `None`.
+             trust_repo: If set to `False`, a prompt will ask the user whether the
+                 repo should be trusted.
+             model_kwargs: Extra model arguments.
+             tensor_transforms: The transforms to apply to the output tensor
+                 produced by the model.
+         """
+         super().__init__(tensor_transforms=tensor_transforms)
+
+         self._model_name = model_name
+         self._repo_or_dir = repo_or_dir
+         self._pretrained = pretrained
+         self._checkpoint_path = checkpoint_path
+         self._out_indices = out_indices
+         self._norm = norm
+         self._trust_repo = trust_repo
+         self._model_kwargs = model_kwargs or {}
+
+         self.load_model()
+
+     @override
+     def load_model(self) -> None:
+         """Builds and loads the torch.hub model."""
+         self._model: nn.Module = torch.hub.load(
+             repo_or_dir=self._repo_or_dir,
+             model=self._model_name,
+             trust_repo=self._trust_repo,
+             pretrained=self._pretrained,
+             **self._model_kwargs,
+         )  # type: ignore
+
+         if self._checkpoint_path:
+             _utils.load_model_weights(self._model, self._checkpoint_path)
+
+         TorchHubModel.__name__ = self._model_name
+
+     @override
+     def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
+         if self._out_indices is not None:
+             if not hasattr(self._model, "get_intermediate_layers"):
+                 raise ValueError(
+                     "Only models with `get_intermediate_layers` are supported "
+                     "when using `out_indices`."
+                 )
+
+             return list(
+                 self._model.get_intermediate_layers(
+                     tensor,
+                     self._out_indices,
+                     reshape=True,
+                     return_class_token=False,
+                     norm=self._norm,
+                 )
+             )
+
+         return self._model(tensor)
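For orientation, a minimal usage sketch of the new wrapper follows. The DINOv2 hub entry, the `wrappers.TorchHubModel` export and the callable forward pass are assumptions, not shown in this diff:

# Hypothetical sketch: wrapping a torch.hub encoder. The repo/model names
# below are assumed examples of a hub entry that exposes
# `get_intermediate_layers` (required whenever `out_indices` is set).
import torch

from eva.core.models import wrappers

encoder = wrappers.TorchHubModel(
    model_name="dinov2_vits14",
    repo_or_dir="facebookresearch/dinov2",
    out_indices=1,  # request only the last intermediate feature map
    norm=True,
)

# Assuming BaseModel makes the wrapper callable like a regular nn.Module,
# this returns a list with a single reshaped feature tensor.
features = encoder(torch.randn(1, 3, 224, 224))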
eva/core/trainers/functional.py CHANGED
@@ -96,11 +96,13 @@ def fit_and_validate(
          A tuple with the validation and the test metrics (if exists).
      """
      trainer.fit(model, datamodule=datamodule)
-     validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose)
+     validation_scores = trainer.validate(
+         datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type
+     )
      test_scores = (
          None
          if datamodule.datasets.test is None
-         else trainer.test(datamodule=datamodule, verbose=verbose)
+         else trainer.test(datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type)
      )
      return validation_scores, test_scores

eva/core/trainers/trainer.py CHANGED
@@ -1,7 +1,7 @@
  """Core trainer module."""

  import os
- from typing import Any
+ from typing import Any, Literal

  import loguru
  from lightning.pytorch import loggers as pl_loggers
@@ -28,6 +28,7 @@ class Trainer(pl_trainer.Trainer):
          *args: Any,
          default_root_dir: str = "logs",
          n_runs: int = 1,
+         checkpoint_type: Literal["best", "last"] = "best",
          **kwargs: Any,
      ) -> None:
          """Initializes the trainer.
@@ -40,11 +41,14 @@ class Trainer(pl_trainer.Trainer):
                  Unlike in ::class::`lightning.pytorch.Trainer`, this path would be the
                  prioritized destination point.
              n_runs: The amount of runs (fit and evaluate) to perform in an evaluation session.
+             checkpoint_type: Whether to load the "best" or "last" checkpoint saved by the checkpoint
+                 callback for evaluations on validation & test sets.
              kwargs: Keyword arguments of ::class::`lightning.pytorch.Trainer`.
          """
          super().__init__(*args, default_root_dir=default_root_dir, **kwargs)

-         self._n_runs = n_runs
+         self.checkpoint_type = checkpoint_type
+         self.n_runs = n_runs

          self._session_id: str = _logging.generate_session_id()
          self._log_dir: str = self.default_log_dir
@@ -106,6 +110,6 @@ class Trainer(pl_trainer.Trainer):
              base_trainer=self,
              base_model=model,
              datamodule=datamodule,
-             n_runs=self._n_runs,
-             verbose=self._n_runs > 1,
+             n_runs=self.n_runs,
+             verbose=self.n_runs > 1,
          )
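The new public attribute is what `fit_and_validate` above forwards as `ckpt_path`. A configuration sketch (import path inferred from the module location; model and datamodule setup omitted):

# Hypothetical sketch: evaluate validation/test on the *last* checkpoint
# instead of the default "best" one.
from eva.core.trainers.trainer import Trainer

trainer = Trainer(
    max_epochs=10,           # regular lightning.pytorch.Trainer kwargs pass through
    n_runs=2,                # repeat fit-and-evaluate twice per session
    checkpoint_type="last",  # forwarded as ckpt_path to validate()/test()
)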
eva/vision/data/datasets/segmentation/_total_segmentator.py ADDED
@@ -0,0 +1,91 @@
+ """Utils for TotalSegmentator dataset classes."""
+
+ from typing import Dict
+
+ reduced_class_mappings: Dict[str, str] = {
+     # Abdominal Organs
+     "spleen": "spleen",
+     "kidney_right": "kidney",
+     "kidney_left": "kidney",
+     "gallbladder": "gallbladder",
+     "liver": "liver",
+     "stomach": "stomach",
+     "pancreas": "pancreas",
+     "small_bowel": "small_bowel",
+     "duodenum": "duodenum",
+     "colon": "colon",
+     # Endocrine System
+     "adrenal_gland_right": "adrenal_gland",
+     "adrenal_gland_left": "adrenal_gland",
+     "thyroid_gland": "thyroid_gland",
+     # Respiratory System
+     "lung_upper_lobe_left": "lungs",
+     "lung_lower_lobe_left": "lungs",
+     "lung_upper_lobe_right": "lungs",
+     "lung_middle_lobe_right": "lungs",
+     "lung_lower_lobe_right": "lungs",
+     "trachea": "trachea",
+     "esophagus": "esophagus",
+     # Urogenital System
+     "urinary_bladder": "urogenital_system",
+     "prostate": "urogenital_system",
+     "kidney_cyst_left": "kidney_cyst",
+     "kidney_cyst_right": "kidney_cyst",
+     # Vertebral Column
+     **{f"vertebrae_{v}": "vertebrae" for v in ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]},
+     **{f"vertebrae_{v}": "vertebrae" for v in [f"T{i}" for i in range(1, 13)]},
+     **{f"vertebrae_{v}": "vertebrae" for v in [f"L{i}" for i in range(1, 6)]},
+     "vertebrae_S1": "vertebrae",
+     "sacrum": "sacral_spine",
+     # Cardiovascular System
+     "heart": "heart",
+     "aorta": "aorta",
+     "pulmonary_vein": "veins",
+     "brachiocephalic_trunk": "arteries",
+     "subclavian_artery_right": "arteries",
+     "subclavian_artery_left": "arteries",
+     "common_carotid_artery_right": "arteries",
+     "common_carotid_artery_left": "arteries",
+     "brachiocephalic_vein_left": "veins",
+     "brachiocephalic_vein_right": "veins",
+     "atrial_appendage_left": "atrial_appendage",
+     "superior_vena_cava": "veins",
+     "inferior_vena_cava": "veins",
+     "portal_vein_and_splenic_vein": "veins",
+     "iliac_artery_left": "arteries",
+     "iliac_artery_right": "arteries",
+     "iliac_vena_left": "veins",
+     "iliac_vena_right": "veins",
+     # Upper Extremity Bones
+     "humerus_left": "humerus",
+     "humerus_right": "humerus",
+     "scapula_left": "scapula",
+     "scapula_right": "scapula",
+     "clavicula_left": "clavicula",
+     "clavicula_right": "clavicula",
+     # Lower Extremity Bones
+     "femur_left": "femur",
+     "femur_right": "femur",
+     "hip_left": "hip",
+     "hip_right": "hip",
+     # Muscles
+     "gluteus_maximus_left": "gluteus",
+     "gluteus_maximus_right": "gluteus",
+     "gluteus_medius_left": "gluteus",
+     "gluteus_medius_right": "gluteus",
+     "gluteus_minimus_left": "gluteus",
+     "gluteus_minimus_right": "gluteus",
+     "autochthon_left": "autochthon",
+     "autochthon_right": "autochthon",
+     "iliopsoas_left": "iliopsoas",
+     "iliopsoas_right": "iliopsoas",
+     # Central Nervous System
+     "brain": "brain",
+     "spinal_cord": "spinal_cord",
+     # Skull and Thoracic Cage
+     "skull": "skull",
+     **{f"rib_left_{i}": "ribs" for i in range(1, 13)},
+     **{f"rib_right_{i}": "ribs" for i in range(1, 13)},
+     "costal_cartilages": "ribs",
+     "sternum": "sternum",
+ }
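A quick sketch of what this mapping induces, mirroring how the dataset class below derives its reduced label set (import path taken from the new file's location):

# Sketch: deriving the reduced label set via sorted(set(...)), as
# total_segmentator_2d.py does further down in this diff.
from eva.vision.data.datasets.segmentation._total_segmentator import (
    reduced_class_mappings,
)

reduced_classes = sorted(set(reduced_class_mappings.values()))
# Paired structures collapse into a single class, e.g. both kidneys:
assert reduced_class_mappings["kidney_left"] == "kidney"
assert reduced_class_mappings["kidney_right"] == "kidney"
print(f"{len(reduced_class_mappings)} original labels -> {len(reduced_classes)} classes")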
eva/vision/data/datasets/segmentation/consep.py CHANGED
@@ -20,9 +20,12 @@ from eva.vision.utils import io
  class CoNSeP(wsi.MultiWsiDataset, base.ImageSegmentation):
      """Dataset class for CoNSeP semantic segmentation task.

-     We combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
+     As in [1], we combine classes 3 (healthy epithelial) & 4 (dysplastic/malignant epithelial)
      into the epithelial class and 5 (fibroblast), 6 (muscle) & 7 (endothelial) into
      the spindle-shaped class.
+
+     [1] Graham, Simon, et al. "Hover-net: Simultaneous segmentation and classification of
+     nuclei in multi-tissue histology images." https://arxiv.org/abs/1802.04712
      """

      _expected_dataset_lengths: Dict[str | None, int] = {
eva/vision/data/datasets/segmentation/lits.py CHANGED
@@ -76,7 +76,7 @@ class LiTS(base.ImageSegmentation):
      @property
      @override
      def classes(self) -> List[str]:
-         return ["liver", "tumor"]
+         return ["background", "liver", "tumor"]

      @functools.cached_property
      @override
@@ -105,8 +105,8 @@ class LiTS(base.ImageSegmentation):
          _validators.check_dataset_integrity(
              self,
              length=self._expected_dataset_lengths.get(self._split, 0),
-             n_classes=2,
-             first_and_last_labels=("liver", "tumor"),
+             n_classes=3,
+             first_and_last_labels=("background", "tumor"),
          )

      @override
eva/vision/data/datasets/segmentation/total_segmentator_2d.py CHANGED
@@ -1,7 +1,9 @@
  """TotalSegmentator 2D segmentation dataset class."""

  import functools
+ import hashlib
  import os
+ import re
  from glob import glob
  from pathlib import Path
  from typing import Any, Callable, Dict, List, Literal, Tuple
@@ -16,7 +18,7 @@ from typing_extensions import override
  from eva.core.utils import io as core_io
  from eva.core.utils import multiprocessing
  from eva.vision.data.datasets import _validators, structs
- from eva.vision.data.datasets.segmentation import base
+ from eva.vision.data.datasets.segmentation import _total_segmentator, base
  from eva.vision.utils import io


@@ -66,6 +68,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
          version: Literal["small", "full"] | None = "full",
          download: bool = False,
          classes: List[str] | None = None,
+         class_mappings: Dict[str, str] | None = _total_segmentator.reduced_class_mappings,
          optimize_mask_loading: bool = True,
          decompress: bool = True,
          num_workers: int = 10,
@@ -85,6 +88,8 @@ class TotalSegmentator2D(base.ImageSegmentation):
                  exist yet on disk.
              classes: Whether to configure the dataset with a subset of classes.
                  If `None`, it will use all of them.
+             class_mappings: A dictionary that maps the original class names to a
+                 reduced set of classes. If `None`, it will use the original classes.
              optimize_mask_loading: Whether to pre-process the segmentation masks
                  in order to optimize the loading time. In the `setup` method, it
                  will reformat the binary one-hot masks to a semantic mask and store
@@ -109,11 +114,10 @@ class TotalSegmentator2D(base.ImageSegmentation):
          self._optimize_mask_loading = optimize_mask_loading
          self._decompress = decompress
          self._num_workers = num_workers
+         self._class_mappings = class_mappings

-         if self._optimize_mask_loading and self._classes is not None:
-             raise ValueError(
-                 "To use customize classes please set the optimize_mask_loading to `False`."
-             )
+         if self._classes and self._class_mappings:
+             raise ValueError("Both 'classes' and 'class_mappings' cannot be set at the same time.")

          self._samples_dirs: List[str] = []
          self._indices: List[Tuple[int, int]] = []
@@ -125,16 +129,21 @@ class TotalSegmentator2D(base.ImageSegmentation):
              """Returns the filename from the full path."""
              return os.path.basename(path).split(".")[0]

-         first_sample_labels = os.path.join(
-             self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
-         )
+         first_sample_labels = os.path.join(self._root, "s0011", "segmentations", "*.nii.gz")
          all_classes = sorted(map(get_filename, glob(first_sample_labels)))
          if self._classes:
              is_subset = all(name in all_classes for name in self._classes)
              if not is_subset:
-                 raise ValueError("Provided class names are not subset of the dataset onces.")
-
-         return all_classes if self._classes is None else self._classes
+                 raise ValueError("Provided class names are not subset of the original ones.")
+             classes = sorted(self._classes)
+         elif self._class_mappings:
+             is_subset = all(name in all_classes for name in self._class_mappings.keys())
+             if not is_subset:
+                 raise ValueError("Provided class names are not subset of the original ones.")
+             classes = sorted(set(self._class_mappings.values()))
+         else:
+             classes = all_classes
+         return ["background"] + classes

      @property
      @override
@@ -145,6 +154,10 @@ class TotalSegmentator2D(base.ImageSegmentation):
      def _file_suffix(self) -> str:
          return "nii" if self._decompress else "nii.gz"

+     @functools.cached_property
+     def _classes_hash(self) -> str:
+         return hashlib.md5(str(self.classes).encode(), usedforsecurity=False).hexdigest()
+
      @override
      def filename(self, index: int) -> str:
          sample_idx, _ = self._indices[index]
@@ -170,15 +183,22 @@ class TotalSegmentator2D(base.ImageSegmentation):
          if self._version is None or self._sample_every_n_slices is not None:
              return

+         if self._classes:
+             last_label = self._classes[-1]
+             n_classes = len(self._classes)
+         elif self._class_mappings:
+             classes = sorted(set(self._class_mappings.values()))
+             last_label = classes[-1]
+             n_classes = len(classes)
+         else:
+             last_label = "vertebrae_T9"
+             n_classes = 117
+
          _validators.check_dataset_integrity(
              self,
              length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
-             n_classes=len(self._classes) if self._classes else 117,
-             first_and_last_labels=(
-                 (self._classes[0], self._classes[-1])
-                 if self._classes
-                 else ("adrenal_gland_left", "vertebrae_T9")
-             ),
+             n_classes=n_classes + 1,
+             first_and_last_labels=("background", last_label),
          )

      @override
@@ -190,32 +210,31 @@ class TotalSegmentator2D(base.ImageSegmentation):
          sample_index, slice_index = self._indices[index]
          image_path = self._get_image_path(sample_index)
          image_array = io.read_nifti(image_path, slice_index)
-         image_rgb_array = image_array.repeat(3, axis=2)
-         return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
+         image_array = self._fix_orientation(image_array)
+         return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))

      @override
      def load_mask(self, index: int) -> tv_tensors.Mask:
          if self._optimize_mask_loading:
-             return self._load_semantic_label_mask(index)
-         return self._load_mask(index)
+             mask = self._load_semantic_label_mask(index)
+         else:
+             mask = self._load_mask(index)
+         mask = self._fix_orientation(mask)
+         return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64)  # type: ignore

      @override
      def load_metadata(self, index: int) -> Dict[str, Any]:
          _, slice_index = self._indices[index]
          return {"slice_index": slice_index}

-     def _load_mask(self, index: int) -> tv_tensors.Mask:
+     def _load_mask(self, index: int) -> npt.NDArray[Any]:
          sample_index, slice_index = self._indices[index]
-         semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
-         return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+         return self._load_masks_as_semantic_label(sample_index, slice_index)

-     def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
+     def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
          """Loads the segmentation mask from a semantic label NifTi file."""
          sample_index, slice_index = self._indices[index]
-         masks_dir = self._get_masks_dir(sample_index)
-         filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
-         semantic_labels = io.read_nifti(filename, slice_index)
-         return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
+         return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)

      def _load_masks_as_semantic_label(
          self, sample_index: int, slice_index: int | None = None
@@ -227,18 +246,39 @@ class TotalSegmentator2D(base.ImageSegmentation):
              slice_index: Whether to return only a specific slice.
          """
          masks_dir = self._get_masks_dir(sample_index)
-         mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes]
+         classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
+         mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
          binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
+
+         if self._class_mappings:
+             mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
+                 self.classes[1:]
+             )
+             for original_class, mapped_class in self._class_mappings.items():
+                 mapped_index = self.class_to_idx[mapped_class] - 1
+                 original_index = list(self._class_mappings.keys()).index(original_class)
+                 mapped_binary_masks[mapped_index] = np.logical_or(
+                     mapped_binary_masks[mapped_index], binary_masks[original_index]
+                 )
+             binary_masks = mapped_binary_masks
+
          background_mask = np.zeros_like(binary_masks[0])
          return np.argmax([background_mask] + binary_masks, axis=0)

      def _export_semantic_label_masks(self) -> None:
          """Exports the segmentation binary masks (one-hot) to semantic labels."""
+         mask_classes_file = os.path.join(f"{self._get_optimized_masks_root()}/classes.txt")
+         if os.path.isfile(mask_classes_file):
+             with open(mask_classes_file, "r") as file:
+                 if file.read() != str(self.classes):
+                     raise ValueError(
+                         "Optimized masks hash doesn't match the current classes or mappings."
+                     )
+             return
+
          total_samples = len(self._samples_dirs)
-         masks_dirs = map(self._get_masks_dir, range(total_samples))
          semantic_labels = [
-             (index, os.path.join(directory, "semantic_labels", "masks.nii"))
-             for index, directory in enumerate(masks_dirs)
+             (index, self._get_optimized_masks_file(index)) for index in range(total_samples)
          ]
          to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
@@ -255,6 +295,16 @@ class TotalSegmentator2D(base.ImageSegmentation):
              return_results=False,
          )

+         os.makedirs(os.path.dirname(mask_classes_file), exist_ok=True)
+         with open(mask_classes_file, "w") as file:
+             file.write(str(self.classes))
+
+     def _fix_orientation(self, array: npt.NDArray):
+         """Fixes orientation such that table is at the bottom & liver on the left."""
+         array = np.rot90(array)
+         array = np.flip(array, axis=1)
+         return array
+
      def _get_image_path(self, sample_index: int) -> str:
          """Returns the corresponding image path."""
          sample_dir = self._samples_dirs[sample_index]
@@ -265,10 +315,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
          sample_dir = self._samples_dirs[sample_index]
          return os.path.join(self._root, sample_dir, "segmentations")

-     def _get_semantic_labels_filename(self, sample_index: int) -> str:
+     def _get_optimized_masks_root(self) -> str:
+         """Returns the directory of the optimized masks."""
+         return os.path.join(self._root, f"processed/masks/{self._classes_hash}")
+
+     def _get_optimized_masks_file(self, sample_index: int) -> str:
          """Returns the semantic label filename."""
-         masks_dir = self._get_masks_dir(sample_index)
-         return os.path.join(masks_dir, "semantic_labels", "masks.nii")
+         return os.path.join(
+             f"{self._get_optimized_masks_root()}/{self._samples_dirs[sample_index]}/masks.nii"
+         )

      def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
          """Returns the total amount of slices of a sample."""
@@ -281,7 +336,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
          sample_filenames = [
              filename
              for filename in os.listdir(self._root)
-             if os.path.isdir(os.path.join(self._root, filename))
+             if os.path.isdir(os.path.join(self._root, filename)) and re.match(r"^s\d{4}$", filename)
          ]
          return sorted(sample_filenames)

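Taken together, a usage sketch of the reworked dataset. The `root` and `split` arguments are assumed from the pre-existing constructor, which this diff does not show:

# Hypothetical sketch of the new class-mapping behaviour.
from eva.vision.data.datasets.segmentation.total_segmentator_2d import (
    TotalSegmentator2D,
)

# Default: labels are merged via reduced_class_mappings, and the optimized
# semantic masks are cached under processed/masks/<md5-of-class-list>/
# during setup (see _get_optimized_masks_root above).
reduced = TotalSegmentator2D(root="data/total_segmentator", split="train")

# Disable the mapping to keep the original per-structure labels; note that
# `classes` and `class_mappings` cannot be combined.
full = TotalSegmentator2D(
    root="data/total_segmentator",
    split="train",
    class_mappings=None,
)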
eva/vision/data/datasets/vision.py CHANGED
@@ -9,7 +9,7 @@ DataSample = TypeVar("DataSample")
  """The data sample type."""


- class VisionDataset(base.Dataset, abc.ABC, Generic[DataSample]):
+ class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
      """Base dataset class for vision tasks."""

      @abc.abstractmethod
@@ -24,20 +24,3 @@ class VisionDataset(base.Dataset, abc.ABC, Generic[DataSample]):
          Returns:
              The filename of the `index`'th data sample.
          """
-
-     @abc.abstractmethod
-     def __getitem__(self, index: int) -> DataSample:
-         """Returns the `index`'th data sample.
-
-         Args:
-             index: The index of the data-sample to select.
-
-         Returns:
-             A data sample and its target.
-         """
-         raise NotImplementedError
-
-     @abc.abstractmethod
-     def __len__(self) -> int:
-         """Returns the total length of the data."""
-         raise NotImplementedError
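With `__getitem__`/`__len__` now declared on the base class, a subclass sketch looks like this. The `MapDataset` contract is inferred from the companion `eva/core/data/datasets/base.py` change and is an assumption:

# Hypothetical minimal VisionDataset subclass after the MapDataset refactor.
from typing_extensions import override

from eva.vision.data.datasets import vision


class ToyDataset(vision.VisionDataset[int]):
    """Trivial dataset backed by a list of integers."""

    def __init__(self, items: list[int]) -> None:
        super().__init__()
        self._items = items

    @override
    def filename(self, index: int) -> str:
        return f"sample_{index}.txt"

    def __getitem__(self, index: int) -> int:
        return self._items[index]

    def __len__(self) -> int:
        return len(self._items)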
eva/vision/losses/dice.py CHANGED
@@ -45,9 +45,6 @@ class DiceLoss(losses.DiceLoss):  # type: ignore
          inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
          targets = _to_one_hot(targets, num_classes=inputs.shape[1])

-         if targets.ndim == 3:
-             targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
-
          return super().forward(inputs, targets)


eva/vision/metrics/__init__.py CHANGED
@@ -1,11 +1,15 @@
  """Default metric collections API."""

  from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+ from eva.vision.metrics.segmentation.dice import DiceScore
  from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
  from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+ from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore

  __all__ = [
-     "MulticlassSegmentationMetrics",
+     "DiceScore",
      "GeneralizedDiceScore",
      "MeanIoU",
+     "MonaiDiceScore",
+     "MulticlassSegmentationMetrics",
  ]
eva/vision/metrics/defaults/segmentation/multiclass.py CHANGED
@@ -1,7 +1,7 @@
  """Default metric collection for multiclass semantic segmentation tasks."""

  from eva.core.metrics import structs
- from eva.vision.metrics.segmentation import generalized_dice, mean_iou
+ from eva.vision.metrics import segmentation


  class MulticlassSegmentationMetrics(structs.MetricCollection):
@@ -26,19 +26,43 @@ class MulticlassSegmentationMetrics(structs.MetricCollection):
              postfix: A string to add after the keys in the output dictionary.
          """
          super().__init__(
-             metrics=[
-                 generalized_dice.GeneralizedDiceScore(
+             metrics={
+                 "MonaiDiceScore": segmentation.MonaiDiceScore(
                      num_classes=num_classes,
                      include_background=include_background,
-                     weight_type="linear",
                      ignore_index=ignore_index,
+                     ignore_empty=True,
                  ),
-                 mean_iou.MeanIoU(
+                 "MonaiDiceScore (ignore_empty=False)": segmentation.MonaiDiceScore(
                      num_classes=num_classes,
                      include_background=include_background,
                      ignore_index=ignore_index,
+                     ignore_empty=False,
                  ),
-             ],
+                 "DiceScore (micro)": segmentation.DiceScore(
+                     num_classes=num_classes,
+                     include_background=include_background,
+                     average="micro",
+                     ignore_index=ignore_index,
+                 ),
+                 "DiceScore (macro)": segmentation.DiceScore(
+                     num_classes=num_classes,
+                     include_background=include_background,
+                     average="macro",
+                     ignore_index=ignore_index,
+                 ),
+                 "DiceScore (weighted)": segmentation.DiceScore(
+                     num_classes=num_classes,
+                     include_background=include_background,
+                     average="weighted",
+                     ignore_index=ignore_index,
+                 ),
+                 "MeanIoU": segmentation.MeanIoU(
+                     num_classes=num_classes,
+                     include_background=include_background,
+                     ignore_index=ignore_index,
+                 ),
+             },
              prefix=prefix,
              postfix=postfix,
          )
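A construction sketch for the expanded collection. The constructor arguments come from the diff; the `keys()` call assumes the usual torchmetrics `MetricCollection` interface:

# Hypothetical sketch: the default collection now bundles six metrics.
from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics

metrics = MulticlassSegmentationMetrics(num_classes=3, include_background=False)
print(sorted(metrics.keys()))
# Expected keys: the three DiceScore averages (micro/macro/weighted), MeanIoU,
# MonaiDiceScore and "MonaiDiceScore (ignore_empty=False)".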
eva/vision/metrics/segmentation/__init__.py CHANGED
@@ -1,9 +1,13 @@
  """Segmentation metrics API."""

+ from eva.vision.metrics.segmentation.dice import DiceScore
  from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
  from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+ from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore

  __all__ = [
+     "DiceScore",
+     "MonaiDiceScore",
      "GeneralizedDiceScore",
      "MeanIoU",
  ]