kaiko-eva 0.1.1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. eva/core/callbacks/writers/embeddings/base.py +3 -4
  2. eva/core/data/dataloaders/dataloader.py +2 -2
  3. eva/core/data/splitting/random.py +6 -5
  4. eva/core/data/splitting/stratified.py +12 -6
  5. eva/core/losses/__init__.py +5 -0
  6. eva/core/losses/cross_entropy.py +27 -0
  7. eva/core/metrics/__init__.py +0 -4
  8. eva/core/metrics/defaults/__init__.py +0 -2
  9. eva/core/models/modules/module.py +9 -9
  10. eva/core/models/transforms/extract_cls_features.py +17 -9
  11. eva/core/models/transforms/extract_patch_features.py +23 -11
  12. eva/core/utils/io/__init__.py +2 -1
  13. eva/core/utils/io/gz.py +28 -0
  14. eva/core/utils/multiprocessing.py +46 -1
  15. eva/core/utils/progress_bar.py +15 -0
  16. eva/vision/callbacks/loggers/batch/segmentation.py +7 -4
  17. eva/vision/data/datasets/__init__.py +4 -0
  18. eva/vision/data/datasets/classification/__init__.py +2 -1
  19. eva/vision/data/datasets/classification/camelyon16.py +4 -1
  20. eva/vision/data/datasets/classification/panda.py +17 -1
  21. eva/vision/data/datasets/classification/wsi.py +4 -1
  22. eva/vision/data/datasets/segmentation/__init__.py +2 -0
  23. eva/vision/data/datasets/segmentation/consep.py +2 -2
  24. eva/vision/data/datasets/segmentation/lits.py +49 -29
  25. eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
  26. eva/vision/data/datasets/segmentation/monusac.py +7 -7
  27. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +50 -18
  28. eva/vision/data/datasets/wsi.py +37 -1
  29. eva/vision/data/wsi/patching/coordinates.py +9 -1
  30. eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
  31. eva/vision/data/wsi/patching/samplers/random.py +4 -2
  32. eva/vision/losses/__init__.py +2 -2
  33. eva/vision/losses/dice.py +75 -8
  34. eva/vision/metrics/__init__.py +11 -0
  35. eva/vision/metrics/defaults/__init__.py +7 -0
  36. eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
  37. eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
  38. eva/vision/metrics/segmentation/BUILD +1 -0
  39. eva/vision/metrics/segmentation/__init__.py +9 -0
  40. eva/vision/metrics/segmentation/_utils.py +69 -0
  41. eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
  42. eva/vision/metrics/segmentation/mean_iou.py +57 -0
  43. eva/vision/models/modules/semantic_segmentation.py +4 -3
  44. eva/vision/models/networks/backbones/_utils.py +12 -0
  45. eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
  46. eva/vision/models/networks/backbones/pathology/histai.py +8 -2
  47. eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
  48. eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
  49. eva/vision/models/networks/backbones/pathology/paige.py +51 -0
  50. eva/vision/models/networks/decoders/__init__.py +1 -1
  51. eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
  52. eva/vision/models/networks/decoders/segmentation/base.py +16 -0
  53. eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
  54. eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
  55. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
  56. eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
  57. eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
  58. eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
  59. eva/vision/utils/colormap.py +20 -0
  60. eva/vision/utils/io/__init__.py +7 -1
  61. eva/vision/utils/io/nifti.py +19 -4
  62. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/METADATA +8 -39
  63. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/RECORD +66 -52
  64. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/WHEEL +1 -1
  65. eva/core/metrics/mean_iou.py +0 -120
  66. eva/vision/models/networks/decoders/decoder.py +0 -7
  67. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/entry_points.txt +0 -0
  68. {kaiko_eva-0.1.1.dist-info → kaiko_eva-0.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -5,12 +5,14 @@ import glob
5
5
  import os
6
6
  from typing import Any, Callable, Dict, List, Literal, Tuple
7
7
 
8
+ import numpy as np
9
+ import numpy.typing as npt
8
10
  import torch
9
11
  from torchvision import tv_tensors
10
12
  from typing_extensions import override
11
13
 
12
14
  from eva.core import utils
13
- from eva.vision.data.datasets import _utils as data_utils
15
+ from eva.core.data import splitting
14
16
  from eva.vision.data.datasets import _validators
15
17
  from eva.vision.data.datasets.segmentation import base
16
18
  from eva.vision.utils import io
@@ -20,22 +22,23 @@ class LiTS(base.ImageSegmentation):
20
22
  """LiTS - Liver Tumor Segmentation Challenge.
21
23
 
22
24
  Webpage: https://competitions.codalab.org/competitions/17094
23
-
24
- For the splits we follow: https://arxiv.org/pdf/2010.01663v2
25
25
  """
26
26
 
27
- _train_index_ranges: List[Tuple[int, int]] = [(0, 102)]
28
- _val_index_ranges: List[Tuple[int, int]] = [(102, 117)]
29
- _test_index_ranges: List[Tuple[int, int]] = [(117, 131)]
27
+ _train_ratio: float = 0.7
28
+ _val_ratio: float = 0.15
29
+ _test_ratio: float = 0.15
30
30
  """Index ranges per split."""
31
31
 
32
+ _fix_orientation: bool = True
33
+ """Whether to fix the orientation of the images to match the default for radiologists."""
34
+
32
35
  _sample_every_n_slices: int | None = None
33
36
  """The amount of slices to sub-sample per 3D CT scan image."""
34
37
 
35
38
  _expected_dataset_lengths: Dict[str | None, int] = {
36
- "train": 39307,
37
- "val": 12045,
38
- "test": 7286,
39
+ "train": 38686,
40
+ "val": 11192,
41
+ "test": 8760,
39
42
  None: 58638,
40
43
  }
41
44
  """Dataset version and split to the expected size."""
@@ -51,6 +54,7 @@ class LiTS(base.ImageSegmentation):
51
54
  root: str,
52
55
  split: Literal["train", "val", "test"] | None = None,
53
56
  transforms: Callable | None = None,
57
+ seed: int = 8,
54
58
  ) -> None:
55
59
  """Initialize dataset.
56
60
 
@@ -60,12 +64,13 @@ class LiTS(base.ImageSegmentation):
60
64
  split: Dataset split to use.
61
65
  transforms: A function/transforms that takes in an image and a target
62
66
  mask and returns the transformed versions of both.
67
+ seed: Seed used for generating the dataset splits.
63
68
  """
64
69
  super().__init__(transforms=transforms)
65
70
 
66
71
  self._root = root
67
72
  self._split = split
68
-
73
+ self._seed = seed
69
74
  self._indices: List[Tuple[int, int]] = []
70
75
 
71
76
  @property
@@ -90,10 +95,12 @@ class LiTS(base.ImageSegmentation):
90
95
 
91
96
  @override
92
97
  def validate(self) -> None:
93
- if len(self._volume_files) != len(self._segmentation_files):
94
- raise ValueError(
95
- "The number of volume files does not match the number of the segmentation ones."
96
- )
98
+ for i in range(len(self._volume_files)):
99
+ seg_path = self._segmentation_file(i)
100
+ if not os.path.exists(seg_path):
101
+ raise FileNotFoundError(
102
+ f"Segmentation file {seg_path} not found for volume {self._volume_files[i]}."
103
+ )
97
104
 
98
105
  _validators.check_dataset_integrity(
99
106
  self,
@@ -107,15 +114,27 @@ class LiTS(base.ImageSegmentation):
107
114
  sample_index, slice_index = self._indices[index]
108
115
  volume_path = self._volume_files[sample_index]
109
116
  image_array = io.read_nifti(volume_path, slice_index)
117
+ if self._fix_orientation:
118
+ image_array = self._orientation(image_array, sample_index)
110
119
  return tv_tensors.Image(image_array.transpose(2, 0, 1))
111
120
 
112
121
  @override
113
122
  def load_mask(self, index: int) -> tv_tensors.Mask:
114
123
  sample_index, slice_index = self._indices[index]
115
- segmentation_path = self._segmentation_files[sample_index]
124
+ segmentation_path = self._segmentation_file(sample_index)
116
125
  semantic_labels = io.read_nifti(segmentation_path, slice_index)
126
+ if self._fix_orientation:
127
+ semantic_labels = self._orientation(semantic_labels, sample_index)
117
128
  return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
118
129
 
130
+ def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray:
131
+ volume_path = self._volume_files[sample_index]
132
+ orientation = io.fetch_nifti_axis_direction_code(volume_path)
133
+ array = np.rot90(array, axes=(0, 1))
134
+ if orientation == "LPS":
135
+ array = np.flip(array, axis=0)
136
+ return array.copy()
137
+
119
138
  @override
120
139
  def load_metadata(self, index: int) -> Dict[str, Any]:
121
140
  _, slice_index = self._indices[index]
@@ -137,11 +156,10 @@ class LiTS(base.ImageSegmentation):
137
156
  files = glob.glob(files_pattern, recursive=True)
138
157
  return utils.numeric_sort(files)
139
158
 
140
- @functools.cached_property
141
- def _segmentation_files(self) -> List[str]:
142
- files_pattern = os.path.join(self._root, "**", "segmentation-*.nii")
143
- files = glob.glob(files_pattern, recursive=True)
144
- return utils.numeric_sort(files)
159
+ def _segmentation_file(self, index: int) -> str:
160
+ volume_file_path = self._volume_files[index]
161
+ segmentation_file = os.path.basename(volume_file_path).replace("volume", "segmentation")
162
+ return os.path.join(os.path.dirname(volume_file_path), segmentation_file)
145
163
 
146
164
  def _create_indices(self) -> List[Tuple[int, int]]:
147
165
  """Builds the dataset indices for the specified split.
@@ -161,17 +179,19 @@ class LiTS(base.ImageSegmentation):
161
179
 
162
180
  def _get_split_indices(self) -> List[int]:
163
181
  """Returns the sample indices for the specified dataset split."""
164
- split_index_ranges = {
165
- "train": self._train_index_ranges,
166
- "val": self._val_index_ranges,
167
- "test": self._test_index_ranges,
168
- None: [(0, len(self._volume_files))],
182
+ indices = list(range(len(self._volume_files)))
183
+ train_indices, val_indices, test_indices = splitting.random_split(
184
+ indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed
185
+ )
186
+ split_indices_dict = {
187
+ "train": train_indices,
188
+ "val": val_indices,
189
+ "test": test_indices,
190
+ None: indices,
169
191
  }
170
- index_ranges = split_index_ranges.get(self._split)
171
- if index_ranges is None:
192
+ if self._split not in split_indices_dict:
172
193
  raise ValueError("Invalid data split. Use 'train', 'val', 'test' or `None`.")
173
-
174
- return data_utils.ranges_to_indices(index_ranges)
194
+ return list(split_indices_dict[self._split])
175
195
 
176
196
  def _print_license(self) -> None:
177
197
  """Prints the dataset license."""
@@ -0,0 +1,93 @@
1
+ """Balanced LiTS dataset."""
2
+
3
+ from typing import Callable, Dict, List, Literal, Tuple
4
+
5
+ import numpy as np
6
+ from typing_extensions import override
7
+
8
+ from eva.vision.data.datasets.segmentation import lits
9
+ from eva.vision.utils import io
10
+
11
+
12
+ class LiTSBalanced(lits.LiTS):
13
+ """Balanced version of the LiTS - Liver Tumor Segmentation Challenge dataset.
14
+
15
+ For each volume in the dataset, we sample the same number of slices where
16
+ only the liver and where both liver and tumor are present.
17
+
18
+ Webpage: https://competitions.codalab.org/competitions/17094
19
+
20
+ For the splits we follow: https://arxiv.org/pdf/2010.01663v2
21
+ """
22
+
23
+ _expected_dataset_lengths: Dict[str | None, int] = {
24
+ "train": 5514,
25
+ "val": 1332,
26
+ "test": 1530,
27
+ None: 8376,
28
+ }
29
+ """Dataset version and split to the expected size."""
30
+
31
+ def __init__(
32
+ self,
33
+ root: str,
34
+ split: Literal["train", "val", "test"] | None = None,
35
+ transforms: Callable | None = None,
36
+ seed: int = 8,
37
+ ) -> None:
38
+ """Initialize dataset.
39
+
40
+ Args:
41
+ root: Path to the root directory of the dataset. The dataset will
42
+ be downloaded and extracted here, if it does not already exist.
43
+ split: Dataset split to use.
44
+ transforms: A function/transforms that takes in an image and a target
45
+ mask and returns the transformed versions of both.
46
+ seed: Seed used for generating the dataset splits and sampling of the slices.
47
+ """
48
+ super().__init__(root=root, split=split, transforms=transforms, seed=seed)
49
+
50
+ @override
51
+ def _create_indices(self) -> List[Tuple[int, int]]:
52
+ """Builds the dataset indices for the specified split.
53
+
54
+ Returns:
55
+ A list of tuples, where the first value indicates the
56
+ sample index which the second its corresponding slice
57
+ index.
58
+ """
59
+ split_indices = set(self._get_split_indices())
60
+ indices: List[Tuple[int, int]] = []
61
+ random_generator = np.random.default_rng(seed=self._seed)
62
+
63
+ for sample_idx in range(len(self._volume_files)):
64
+ if sample_idx not in split_indices:
65
+ continue
66
+
67
+ segmentation = io.read_nifti(self._segmentation_file(sample_idx))
68
+ tumor_filter = segmentation == 2
69
+ tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
70
+
71
+ if tumor_filter.sum() == 0:
72
+ continue
73
+
74
+ liver_filter = segmentation == 1
75
+ liver_slice_filter = liver_filter.sum(axis=(0, 1)) > 0
76
+
77
+ liver_and_tumor_filter = liver_slice_filter & tumor_slice_filter
78
+ liver_only_filter = liver_slice_filter & ~tumor_slice_filter
79
+
80
+ n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum())
81
+ tumor_indices = list(np.where(liver_and_tumor_filter)[0])
82
+ tumor_indices = list(
83
+ random_generator.choice(tumor_indices, size=n_slice_samples, replace=False)
84
+ )
85
+
86
+ liver_indices = list(np.where(liver_only_filter)[0])
87
+ liver_indices = list(
88
+ random_generator.choice(liver_indices, size=n_slice_samples, replace=False)
89
+ )
90
+
91
+ indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])
92
+
93
+ return list(indices)
@@ -10,12 +10,12 @@ import imagesize
10
10
  import numpy as np
11
11
  import numpy.typing as npt
12
12
  import torch
13
- import tqdm
14
13
  from skimage import draw
15
14
  from torchvision import tv_tensors
16
15
  from torchvision.datasets import utils
17
16
  from typing_extensions import override
18
17
 
18
+ from eva.core.utils.progress_bar import tqdm
19
19
  from eva.vision.data.datasets import _validators, structs
20
20
  from eva.vision.data.datasets.segmentation import base
21
21
  from eva.vision.utils import io
@@ -84,7 +84,7 @@ class MoNuSAC(base.ImageSegmentation):
84
84
  @property
85
85
  @override
86
86
  def classes(self) -> List[str]:
87
- return ["Epithelial", "Lymphocyte", "Neutrophil", "Macrophage"]
87
+ return ["Background", "Epithelial", "Lymphocyte", "Neutrophil", "Macrophage", "Ambiguous"]
88
88
 
89
89
  @functools.cached_property
90
90
  @override
@@ -107,8 +107,8 @@ class MoNuSAC(base.ImageSegmentation):
107
107
  _validators.check_dataset_integrity(
108
108
  self,
109
109
  length=self._expected_dataset_lengths.get(self._split, 0),
110
- n_classes=4,
111
- first_and_last_labels=("Epithelial", "Macrophage"),
110
+ n_classes=6,
111
+ first_and_last_labels=("Background", "Ambiguous"),
112
112
  )
113
113
 
114
114
  @override
@@ -161,7 +161,7 @@ class MoNuSAC(base.ImageSegmentation):
161
161
  for index, filename in enumerate(self._image_files)
162
162
  ]
163
163
  to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
164
- for sample_index, filename in tqdm.tqdm(
164
+ for sample_index, filename in tqdm(
165
165
  list(to_export),
166
166
  desc=">> Exporting semantic masks",
167
167
  leave=False,
@@ -199,9 +199,9 @@ class MoNuSAC(base.ImageSegmentation):
199
199
  semantic_labels = np.zeros((height, width), "uint8") # type: ignore[reportCallIssue]
200
200
  for level in range(len(root)):
201
201
  label = [item.attrib["Name"] for item in root[level][0]][0]
202
- class_id = self.class_to_idx.get(label, 254) + 1
202
+ class_id = self.class_to_idx.get(label, self.class_to_idx["Ambiguous"])
203
203
  # for the test dataset an additional class 'Ambiguous' was added for
204
- # difficult regions with fuzzy boundaries - we return it as 255
204
+ # difficult regions with fuzzy boundaries
205
205
  regions = [item for child in root[level] for item in child if item.tag == "Region"]
206
206
  for region in regions:
207
207
  vertices = np.array(
@@ -3,16 +3,18 @@
3
3
  import functools
4
4
  import os
5
5
  from glob import glob
6
+ from pathlib import Path
6
7
  from typing import Any, Callable, Dict, List, Literal, Tuple
7
8
 
8
9
  import numpy as np
9
10
  import numpy.typing as npt
10
11
  import torch
11
- import tqdm
12
12
  from torchvision import tv_tensors
13
13
  from torchvision.datasets import utils
14
14
  from typing_extensions import override
15
15
 
16
+ from eva.core.utils import io as core_io
17
+ from eva.core.utils import multiprocessing
16
18
  from eva.vision.data.datasets import _validators, structs
17
19
  from eva.vision.data.datasets.segmentation import base
18
20
  from eva.vision.utils import io
@@ -65,6 +67,8 @@ class TotalSegmentator2D(base.ImageSegmentation):
65
67
  download: bool = False,
66
68
  classes: List[str] | None = None,
67
69
  optimize_mask_loading: bool = True,
70
+ decompress: bool = True,
71
+ num_workers: int = 10,
68
72
  transforms: Callable | None = None,
69
73
  ) -> None:
70
74
  """Initialize dataset.
@@ -85,8 +89,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
85
89
  in order to optimize the loading time. In the `setup` method, it
86
90
  will reformat the binary one-hot masks to a semantic mask and store
87
91
  it on disk.
92
+ decompress: Whether to decompress the ct.nii.gz files when preparing the data.
93
+ The label masks won't be decompressed, but when enabling optimize_mask_loading
94
+ it will export the semantic label masks to a single file in uncompressed .nii
95
+ format.
96
+ num_workers: The number of workers to use for optimizing the masks &
97
+ decompressing the .gz files.
88
98
  transforms: A function/transforms that takes in an image and a target
89
99
  mask and returns the transformed versions of both.
100
+
90
101
  """
91
102
  super().__init__(transforms=transforms)
92
103
 
@@ -96,6 +107,8 @@ class TotalSegmentator2D(base.ImageSegmentation):
96
107
  self._download = download
97
108
  self._classes = classes
98
109
  self._optimize_mask_loading = optimize_mask_loading
110
+ self._decompress = decompress
111
+ self._num_workers = num_workers
99
112
 
100
113
  if self._optimize_mask_loading and self._classes is not None:
101
114
  raise ValueError(
@@ -128,23 +141,29 @@ class TotalSegmentator2D(base.ImageSegmentation):
128
141
  def class_to_idx(self) -> Dict[str, int]:
129
142
  return {label: index for index, label in enumerate(self.classes)}
130
143
 
144
+ @property
145
+ def _file_suffix(self) -> str:
146
+ return "nii" if self._decompress else "nii.gz"
147
+
131
148
  @override
132
- def filename(self, index: int, segmented: bool = True) -> str:
149
+ def filename(self, index: int) -> str:
133
150
  sample_idx, _ = self._indices[index]
134
151
  sample_dir = self._samples_dirs[sample_idx]
135
- return os.path.join(sample_dir, "ct.nii.gz")
152
+ return os.path.join(sample_dir, f"ct.{self._file_suffix}")
136
153
 
137
154
  @override
138
155
  def prepare_data(self) -> None:
139
156
  if self._download:
140
157
  self._download_dataset()
158
+ if self._decompress:
159
+ self._decompress_files()
160
+ self._samples_dirs = self._fetch_samples_dirs()
161
+ if self._optimize_mask_loading:
162
+ self._export_semantic_label_masks()
141
163
 
142
164
  @override
143
165
  def configure(self) -> None:
144
- self._samples_dirs = self._fetch_samples_dirs()
145
166
  self._indices = self._create_indices()
146
- if self._optimize_mask_loading:
147
- self._export_semantic_label_masks()
148
167
 
149
168
  @override
150
169
  def validate(self) -> None:
@@ -186,16 +205,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
186
205
  return {"slice_index": slice_index}
187
206
 
188
207
  def _load_mask(self, index: int) -> tv_tensors.Mask:
189
- """Loads and builds the segmentation mask from NifTi files."""
190
208
  sample_index, slice_index = self._indices[index]
191
209
  semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
192
- return tv_tensors.Mask(semantic_labels, dtype=torch.int64) # type: ignore[reportCallIssue]
210
+ return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
193
211
 
194
212
  def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
195
213
  """Loads the segmentation mask from a semantic label NifTi file."""
196
214
  sample_index, slice_index = self._indices[index]
197
215
  masks_dir = self._get_masks_dir(sample_index)
198
- filename = os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
216
+ filename = os.path.join(masks_dir, "semantic_labels", "masks.nii")
199
217
  semantic_labels = io.read_nifti(filename, slice_index)
200
218
  return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
201
219
 
@@ -209,7 +227,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
209
227
  slice_index: Whether to return only a specific slice.
210
228
  """
211
229
  masks_dir = self._get_masks_dir(sample_index)
212
- mask_paths = [os.path.join(masks_dir, label + ".nii.gz") for label in self.classes]
230
+ mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in self.classes]
213
231
  binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
214
232
  background_mask = np.zeros_like(binary_masks[0])
215
233
  return np.argmax([background_mask] + binary_masks, axis=0)
@@ -219,24 +237,28 @@ class TotalSegmentator2D(base.ImageSegmentation):
219
237
  total_samples = len(self._samples_dirs)
220
238
  masks_dirs = map(self._get_masks_dir, range(total_samples))
221
239
  semantic_labels = [
222
- (index, os.path.join(directory, "semantic_labels", "masks.nii.gz"))
240
+ (index, os.path.join(directory, "semantic_labels", "masks.nii"))
223
241
  for index, directory in enumerate(masks_dirs)
224
242
  ]
225
243
  to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
226
244
 
227
- for sample_index, filename in tqdm.tqdm(
228
- list(to_export),
229
- desc=">> Exporting optimized semantic masks",
230
- leave=False,
231
- ):
245
+ def _process_mask(sample_index: Any, filename: str) -> None:
232
246
  semantic_labels = self._load_masks_as_semantic_label(sample_index)
233
247
  os.makedirs(os.path.dirname(filename), exist_ok=True)
234
248
  io.save_array_as_nifti(semantic_labels, filename)
235
249
 
250
+ multiprocessing.run_with_threads(
251
+ _process_mask,
252
+ list(to_export),
253
+ num_workers=self._num_workers,
254
+ progress_desc=">> Exporting optimized semantic mask",
255
+ return_results=False,
256
+ )
257
+
236
258
  def _get_image_path(self, sample_index: int) -> str:
237
259
  """Returns the corresponding image path."""
238
260
  sample_dir = self._samples_dirs[sample_index]
239
- return os.path.join(self._root, sample_dir, "ct.nii.gz")
261
+ return os.path.join(self._root, sample_dir, f"ct.{self._file_suffix}")
240
262
 
241
263
  def _get_masks_dir(self, sample_index: int) -> str:
242
264
  """Returns the directory of the corresponding masks."""
@@ -246,7 +268,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
246
268
  def _get_semantic_labels_filename(self, sample_index: int) -> str:
247
269
  """Returns the semantic label filename."""
248
270
  masks_dir = self._get_masks_dir(sample_index)
249
- return os.path.join(masks_dir, "semantic_labels", "masks.nii.gz")
271
+ return os.path.join(masks_dir, "semantic_labels", "masks.nii")
250
272
 
251
273
  def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
252
274
  """Returns the total amount of slices of a sample."""
@@ -320,6 +342,16 @@ class TotalSegmentator2D(base.ImageSegmentation):
320
342
  remove_finished=True,
321
343
  )
322
344
 
345
+ def _decompress_files(self) -> None:
346
+ compressed_paths = Path(self._root).rglob("*/ct.nii.gz")
347
+ multiprocessing.run_with_threads(
348
+ core_io.gunzip_file,
349
+ [(str(path),) for path in compressed_paths],
350
+ num_workers=self._num_workers,
351
+ progress_desc=">> Decompressing .gz files",
352
+ return_results=False,
353
+ )
354
+
323
355
  def _print_license(self) -> None:
324
356
  """Prints the dataset license."""
325
357
  print(f"Dataset license: {self._license}")
@@ -2,8 +2,9 @@
2
2
 
3
3
  import bisect
4
4
  import os
5
- from typing import Callable, List
5
+ from typing import Any, Callable, Dict, List
6
6
 
7
+ import pandas as pd
7
8
  from loguru import logger
8
9
  from torch.utils.data import dataset as torch_datasets
9
10
  from torchvision import tv_tensors
@@ -85,6 +86,17 @@ class WsiDataset(vision.VisionDataset):
85
86
  patch = self._apply_transforms(patch)
86
87
  return patch
87
88
 
89
+ def load_metadata(self, index: int) -> Dict[str, Any]:
90
+ """Loads the metadata for the patch at the specified index."""
91
+ x, y = self._coords.x_y[index]
92
+ return {
93
+ "x": x,
94
+ "y": y,
95
+ "width": self._coords.width,
96
+ "height": self._coords.height,
97
+ "level_idx": self._coords.level_idx,
98
+ }
99
+
88
100
  def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
89
101
  if self._image_transforms is not None:
90
102
  image = self._image_transforms(image)
@@ -105,6 +117,7 @@ class MultiWsiDataset(vision.VisionDataset):
105
117
  overwrite_mpp: float | None = None,
106
118
  backend: str = "openslide",
107
119
  image_transforms: Callable | None = None,
120
+ coords_path: str | None = None,
108
121
  ):
109
122
  """Initializes a new dataset instance.
110
123
 
@@ -118,6 +131,7 @@ class MultiWsiDataset(vision.VisionDataset):
118
131
  sampler: The sampler to use for sampling patch coordinates.
119
132
  backend: The backend to use for reading the whole-slide images.
120
133
  image_transforms: Transforms to apply to the extracted image patches.
134
+ coords_path: File path to save the patch coordinates as .csv.
121
135
  """
122
136
  super().__init__()
123
137
 
@@ -130,6 +144,7 @@ class MultiWsiDataset(vision.VisionDataset):
130
144
  self._sampler = sampler
131
145
  self._backend = backend
132
146
  self._image_transforms = image_transforms
147
+ self._coords_path = coords_path
133
148
 
134
149
  self._concat_dataset: torch_datasets.ConcatDataset
135
150
 
@@ -146,6 +161,7 @@ class MultiWsiDataset(vision.VisionDataset):
146
161
  @override
147
162
  def configure(self) -> None:
148
163
  self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
164
+ self._save_coords_to_file()
149
165
 
150
166
  @override
151
167
  def __len__(self) -> int:
@@ -159,6 +175,12 @@ class MultiWsiDataset(vision.VisionDataset):
159
175
  def filename(self, index: int) -> str:
160
176
  return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
161
177
 
178
+ def load_metadata(self, index: int) -> Dict[str, Any]:
179
+ """Loads the metadata for the patch at the specified index."""
180
+ dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
181
+ patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
182
+ return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata
183
+
162
184
  def _load_datasets(self) -> list[WsiDataset]:
163
185
  logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
164
186
  wsi_datasets = []
@@ -185,3 +207,17 @@ class MultiWsiDataset(vision.VisionDataset):
185
207
 
186
208
  def _get_dataset_idx(self, index: int) -> int:
187
209
  return bisect.bisect_right(self.cumulative_sizes, index)
210
+
211
+ def _get_sample_idx(self, index: int) -> int:
212
+ dataset_idx = self._get_dataset_idx(index)
213
+ return index if dataset_idx == 0 else index - self.cumulative_sizes[dataset_idx - 1]
214
+
215
+ def _save_coords_to_file(self):
216
+ if self._coords_path is not None:
217
+ coords = [
218
+ {"file": self._file_paths[i]} | dataset._coords.to_dict()
219
+ for i, dataset in enumerate(self.datasets)
220
+ ]
221
+ os.makedirs(os.path.abspath(os.path.join(self._coords_path, os.pardir)), exist_ok=True)
222
+ pd.DataFrame(coords).to_csv(self._coords_path, index=False)
223
+ logger.info(f"Saved patch coordinates to: {self._coords_path}")
@@ -2,7 +2,7 @@
2
2
 
3
3
  import dataclasses
4
4
  import functools
5
- from typing import List, Tuple
5
+ from typing import Any, Dict, List, Tuple
6
6
 
7
7
  from eva.vision.data.wsi import backends
8
8
  from eva.vision.data.wsi.patching import samplers
@@ -75,6 +75,14 @@ class PatchCoordinates:
75
75
 
76
76
  return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))
77
77
 
78
+ def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
79
+ """Convert the coordinates to a dictionary."""
80
+ include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
81
+ coord_dict = dataclasses.asdict(self)
82
+ if include_keys:
83
+ coord_dict = {key: coord_dict[key] for key in include_keys}
84
+ return coord_dict
85
+
78
86
 
79
87
  @functools.lru_cache(LRU_CACHE_SIZE)
80
88
  def get_cached_coords(
@@ -1,14 +1,8 @@
1
- import random
2
1
  from typing import Tuple
3
2
 
4
3
  import numpy as np
5
4
 
6
5
 
7
- def set_seed(seed: int) -> None:
8
- random.seed(seed)
9
- np.random.seed(seed)
10
-
11
-
12
6
  def get_grid_coords_and_indices(
13
7
  layer_shape: Tuple[int, int],
14
8
  width: int,
@@ -33,8 +27,8 @@ def get_grid_coords_and_indices(
33
27
 
34
28
  indices = list(range(len(x_y)))
35
29
  if shuffle:
36
- set_seed(seed)
37
- np.random.shuffle(indices)
30
+ random_generator = np.random.default_rng(seed)
31
+ random_generator.shuffle(indices)
38
32
  return x_y, indices
39
33
 
40
34
 
@@ -18,6 +18,7 @@ class RandomSampler(base.Sampler):
18
18
  """Initializes the sampler."""
19
19
  self.seed = seed
20
20
  self.n_samples = n_samples
21
+ self.random_generator = random.Random(seed) # nosec
21
22
 
22
23
  def sample(
23
24
  self,
@@ -33,9 +34,10 @@ class RandomSampler(base.Sampler):
33
34
  layer_shape: The shape of the layer.
34
35
  """
35
36
  _utils.validate_dimensions(width, height, layer_shape)
36
- _utils.set_seed(self.seed)
37
37
 
38
38
  x_max, y_max = layer_shape[0], layer_shape[1]
39
39
  for _ in range(self.n_samples):
40
- x, y = random.randint(0, x_max - width), random.randint(0, y_max - height) # nosec
40
+ x, y = self.random_generator.randint(0, x_max - width), self.random_generator.randint(
41
+ 0, y_max - height
42
+ )
41
43
  yield x, y
@@ -1,5 +1,5 @@
1
1
  """Loss functions API."""
2
2
 
3
- from eva.vision.losses.dice import DiceLoss
3
+ from eva.vision.losses.dice import DiceCELoss, DiceLoss
4
4
 
5
- __all__ = ["DiceLoss"]
5
+ __all__ = ["DiceLoss", "DiceCELoss"]