kaiko-eva 0.1.0__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of kaiko-eva was flagged as potentially problematic.
Files changed (63)
  1. eva/core/callbacks/writers/embeddings/base.py +3 -4
  2. eva/core/data/dataloaders/dataloader.py +2 -2
  3. eva/core/data/splitting/random.py +6 -5
  4. eva/core/data/splitting/stratified.py +12 -6
  5. eva/core/losses/__init__.py +5 -0
  6. eva/core/losses/cross_entropy.py +27 -0
  7. eva/core/metrics/__init__.py +0 -4
  8. eva/core/metrics/defaults/__init__.py +0 -2
  9. eva/core/models/modules/module.py +9 -9
  10. eva/core/models/transforms/extract_cls_features.py +17 -9
  11. eva/core/models/transforms/extract_patch_features.py +23 -11
  12. eva/core/utils/progress_bar.py +15 -0
  13. eva/vision/data/datasets/__init__.py +4 -0
  14. eva/vision/data/datasets/classification/__init__.py +2 -1
  15. eva/vision/data/datasets/classification/camelyon16.py +4 -1
  16. eva/vision/data/datasets/classification/panda.py +17 -1
  17. eva/vision/data/datasets/classification/wsi.py +4 -1
  18. eva/vision/data/datasets/segmentation/__init__.py +2 -0
  19. eva/vision/data/datasets/segmentation/consep.py +2 -2
  20. eva/vision/data/datasets/segmentation/lits.py +49 -29
  21. eva/vision/data/datasets/segmentation/lits_balanced.py +93 -0
  22. eva/vision/data/datasets/segmentation/monusac.py +7 -7
  23. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +2 -2
  24. eva/vision/data/datasets/wsi.py +37 -1
  25. eva/vision/data/wsi/patching/coordinates.py +9 -1
  26. eva/vision/data/wsi/patching/samplers/_utils.py +2 -8
  27. eva/vision/data/wsi/patching/samplers/random.py +4 -2
  28. eva/vision/losses/__init__.py +2 -2
  29. eva/vision/losses/dice.py +75 -8
  30. eva/vision/metrics/__init__.py +11 -0
  31. eva/vision/metrics/defaults/__init__.py +7 -0
  32. eva/{core → vision}/metrics/defaults/segmentation/__init__.py +1 -1
  33. eva/{core → vision}/metrics/defaults/segmentation/multiclass.py +2 -1
  34. eva/vision/metrics/segmentation/BUILD +1 -0
  35. eva/vision/metrics/segmentation/__init__.py +9 -0
  36. eva/vision/metrics/segmentation/_utils.py +69 -0
  37. eva/{core/metrics → vision/metrics/segmentation}/generalized_dice.py +12 -10
  38. eva/vision/metrics/segmentation/mean_iou.py +57 -0
  39. eva/vision/models/modules/semantic_segmentation.py +4 -3
  40. eva/vision/models/networks/backbones/_utils.py +12 -0
  41. eva/vision/models/networks/backbones/pathology/__init__.py +4 -1
  42. eva/vision/models/networks/backbones/pathology/histai.py +8 -2
  43. eva/vision/models/networks/backbones/pathology/mahmood.py +2 -9
  44. eva/vision/models/networks/backbones/pathology/owkin.py +14 -0
  45. eva/vision/models/networks/backbones/pathology/paige.py +51 -0
  46. eva/vision/models/networks/decoders/__init__.py +1 -1
  47. eva/vision/models/networks/decoders/segmentation/__init__.py +12 -4
  48. eva/vision/models/networks/decoders/segmentation/base.py +16 -0
  49. eva/vision/models/networks/decoders/segmentation/{conv2d.py → decoder2d.py} +26 -22
  50. eva/vision/models/networks/decoders/segmentation/linear.py +2 -2
  51. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +12 -0
  52. eva/vision/models/networks/decoders/segmentation/{common.py → semantic/common.py} +3 -3
  53. eva/vision/models/networks/decoders/segmentation/semantic/with_image.py +94 -0
  54. eva/vision/models/networks/decoders/segmentation/typings.py +18 -0
  55. eva/vision/utils/io/__init__.py +7 -1
  56. eva/vision/utils/io/nifti.py +19 -4
  57. {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/METADATA +3 -34
  58. {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/RECORD +61 -48
  59. {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/WHEEL +1 -1
  60. eva/core/metrics/mean_iou.py +0 -120
  61. eva/vision/models/networks/decoders/decoder.py +0 -7
  62. {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/entry_points.txt +0 -0
  63. {kaiko_eva-0.1.0.dist-info → kaiko_eva-0.1.3.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/lits.py CHANGED
@@ -5,12 +5,14 @@ import glob
 import os
 from typing import Any, Callable, Dict, List, Literal, Tuple
 
+import numpy as np
+import numpy.typing as npt
 import torch
 from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.core import utils
-from eva.vision.data.datasets import _utils as data_utils
+from eva.core.data import splitting
 from eva.vision.data.datasets import _validators
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -20,22 +22,23 @@ class LiTS(base.ImageSegmentation):
     """LiTS - Liver Tumor Segmentation Challenge.
 
     Webpage: https://competitions.codalab.org/competitions/17094
-
-    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
     """
 
-    _train_index_ranges: List[Tuple[int, int]] = [(0, 102)]
-    _val_index_ranges: List[Tuple[int, int]] = [(102, 117)]
-    _test_index_ranges: List[Tuple[int, int]] = [(117, 131)]
+    _train_ratio: float = 0.7
+    _val_ratio: float = 0.15
+    _test_ratio: float = 0.15
     """Index ranges per split."""
 
+    _fix_orientation: bool = True
+    """Whether to fix the orientation of the images to match the default for radiologists."""
+
     _sample_every_n_slices: int | None = None
     """The amount of slices to sub-sample per 3D CT scan image."""
 
     _expected_dataset_lengths: Dict[str | None, int] = {
-        "train": 39307,
-        "val": 12045,
-        "test": 7286,
+        "train": 38686,
+        "val": 11192,
+        "test": 8760,
         None: 58638,
     }
     """Dataset version and split to the expected size."""
@@ -51,6 +54,7 @@ class LiTS(base.ImageSegmentation):
         root: str,
         split: Literal["train", "val", "test"] | None = None,
         transforms: Callable | None = None,
+        seed: int = 8,
     ) -> None:
         """Initialize dataset.
 
@@ -60,12 +64,13 @@ class LiTS(base.ImageSegmentation):
             split: Dataset split to use.
             transforms: A function/transforms that takes in an image and a target
                 mask and returns the transformed versions of both.
+            seed: Seed used for generating the dataset splits.
         """
         super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
-
+        self._seed = seed
         self._indices: List[Tuple[int, int]] = []
 
     @property
@@ -90,10 +95,12 @@ class LiTS(base.ImageSegmentation):
 
     @override
     def validate(self) -> None:
-        if len(self._volume_files) != len(self._segmentation_files):
-            raise ValueError(
-                "The number of volume files does not match the number of the segmentation ones."
-            )
+        for i in range(len(self._volume_files)):
+            seg_path = self._segmentation_file(i)
+            if not os.path.exists(seg_path):
+                raise FileNotFoundError(
+                    f"Segmentation file {seg_path} not found for volume {self._volume_files[i]}."
+                )
 
         _validators.check_dataset_integrity(
             self,
@@ -107,15 +114,27 @@ class LiTS(base.ImageSegmentation):
         sample_index, slice_index = self._indices[index]
         volume_path = self._volume_files[sample_index]
         image_array = io.read_nifti(volume_path, slice_index)
+        if self._fix_orientation:
+            image_array = self._orientation(image_array, sample_index)
         return tv_tensors.Image(image_array.transpose(2, 0, 1))
 
     @override
     def load_mask(self, index: int) -> tv_tensors.Mask:
         sample_index, slice_index = self._indices[index]
-        segmentation_path = self._segmentation_files[sample_index]
+        segmentation_path = self._segmentation_file(sample_index)
         semantic_labels = io.read_nifti(segmentation_path, slice_index)
+        if self._fix_orientation:
+            semantic_labels = self._orientation(semantic_labels, sample_index)
         return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64)  # type: ignore[reportCallIssue]
 
+    def _orientation(self, array: npt.NDArray, sample_index: int) -> npt.NDArray:
+        volume_path = self._volume_files[sample_index]
+        orientation = io.fetch_nifti_axis_direction_code(volume_path)
+        array = np.rot90(array, axes=(0, 1))
+        if orientation == "LPS":
+            array = np.flip(array, axis=0)
+        return array.copy()
+
     @override
     def load_metadata(self, index: int) -> Dict[str, Any]:
         _, slice_index = self._indices[index]
@@ -137,11 +156,10 @@ class LiTS(base.ImageSegmentation):
         files = glob.glob(files_pattern, recursive=True)
         return utils.numeric_sort(files)
 
-    @functools.cached_property
-    def _segmentation_files(self) -> List[str]:
-        files_pattern = os.path.join(self._root, "**", "segmentation-*.nii")
-        files = glob.glob(files_pattern, recursive=True)
-        return utils.numeric_sort(files)
+    def _segmentation_file(self, index: int) -> str:
+        volume_file_path = self._volume_files[index]
+        segmentation_file = os.path.basename(volume_file_path).replace("volume", "segmentation")
+        return os.path.join(os.path.dirname(volume_file_path), segmentation_file)
 
     def _create_indices(self) -> List[Tuple[int, int]]:
         """Builds the dataset indices for the specified split.
@@ -161,17 +179,19 @@ class LiTS(base.ImageSegmentation):
 
     def _get_split_indices(self) -> List[int]:
         """Returns the sample indices for the specified dataset split."""
-        split_index_ranges = {
-            "train": self._train_index_ranges,
-            "val": self._val_index_ranges,
-            "test": self._test_index_ranges,
-            None: [(0, len(self._volume_files))],
+        indices = list(range(len(self._volume_files)))
+        train_indices, val_indices, test_indices = splitting.random_split(
+            indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed
+        )
+        split_indices_dict = {
+            "train": train_indices,
+            "val": val_indices,
+            "test": test_indices,
+            None: indices,
         }
-        index_ranges = split_index_ranges.get(self._split)
-        if index_ranges is None:
+        if self._split not in split_indices_dict:
             raise ValueError("Invalid data split. Use 'train', 'val', 'test' or `None`.")
-
-        return data_utils.ranges_to_indices(index_ranges)
+        return list(split_indices_dict[self._split])
 
     def _print_license(self) -> None:
         """Prints the dataset license."""
eva/vision/data/datasets/segmentation/lits_balanced.py ADDED
@@ -0,0 +1,93 @@
+"""Balanced LiTS dataset."""
+
+from typing import Callable, Dict, List, Literal, Tuple
+
+import numpy as np
+from typing_extensions import override
+
+from eva.vision.data.datasets.segmentation import lits
+from eva.vision.utils import io
+
+
+class LiTSBalanced(lits.LiTS):
+    """Balanced version of the LiTS - Liver Tumor Segmentation Challenge dataset.
+
+    For each volume in the dataset, we sample the same number of slices where
+    only the liver and where both liver and tumor are present.
+
+    Webpage: https://competitions.codalab.org/competitions/17094
+
+    For the splits we follow: https://arxiv.org/pdf/2010.01663v2
+    """
+
+    _expected_dataset_lengths: Dict[str | None, int] = {
+        "train": 5514,
+        "val": 1332,
+        "test": 1530,
+        None: 8376,
+    }
+    """Dataset version and split to the expected size."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val", "test"] | None = None,
+        transforms: Callable | None = None,
+        seed: int = 8,
+    ) -> None:
+        """Initialize dataset.
+
+        Args:
+            root: Path to the root directory of the dataset. The dataset will
+                be downloaded and extracted here, if it does not already exist.
+            split: Dataset split to use.
+            transforms: A function/transforms that takes in an image and a target
+                mask and returns the transformed versions of both.
+            seed: Seed used for generating the dataset splits and sampling of the slices.
+        """
+        super().__init__(root=root, split=split, transforms=transforms, seed=seed)
+
+    @override
+    def _create_indices(self) -> List[Tuple[int, int]]:
+        """Builds the dataset indices for the specified split.
+
+        Returns:
+            A list of tuples, where the first value indicates the
+            sample index which the second its corresponding slice
+            index.
+        """
+        split_indices = set(self._get_split_indices())
+        indices: List[Tuple[int, int]] = []
+        random_generator = np.random.default_rng(seed=self._seed)
+
+        for sample_idx in range(len(self._volume_files)):
+            if sample_idx not in split_indices:
+                continue
+
+            segmentation = io.read_nifti(self._segmentation_file(sample_idx))
+            tumor_filter = segmentation == 2
+            tumor_slice_filter = tumor_filter.sum(axis=(0, 1)) > 0
+
+            if tumor_filter.sum() == 0:
+                continue
+
+            liver_filter = segmentation == 1
+            liver_slice_filter = liver_filter.sum(axis=(0, 1)) > 0
+
+            liver_and_tumor_filter = liver_slice_filter & tumor_slice_filter
+            liver_only_filter = liver_slice_filter & ~tumor_slice_filter
+
+            n_slice_samples = min(liver_and_tumor_filter.sum(), liver_only_filter.sum())
+            tumor_indices = list(np.where(liver_and_tumor_filter)[0])
+            tumor_indices = list(
+                random_generator.choice(tumor_indices, size=n_slice_samples, replace=False)
+            )
+
+            liver_indices = list(np.where(liver_only_filter)[0])
+            liver_indices = list(
+                random_generator.choice(liver_indices, size=n_slice_samples, replace=False)
+            )
+
+            indices.extend([(sample_idx, slice_idx) for slice_idx in tumor_indices + liver_indices])
+
+        return list(indices)
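
The balancing in `_create_indices` reduces to: per volume, take `n = min(#liver+tumor slices, #liver-only slices)` and draw `n` slices from each group without replacement. A self-contained toy version of that selection, on a synthetic mask rather than LiTS data:

```python
import numpy as np

# Toy H x W x slices mask: liver in slices 0..13, tumor only in slices 10..13.
segmentation = np.zeros((8, 8, 20), dtype=np.int64)
segmentation[2:6, 2:6, 0:14] = 1   # liver (class 1)
segmentation[3:5, 3:5, 10:14] = 2  # tumor (class 2)

liver = (segmentation == 1).sum(axis=(0, 1)) > 0
tumor = (segmentation == 2).sum(axis=(0, 1)) > 0
both, liver_only = liver & tumor, liver & ~tumor

rng = np.random.default_rng(seed=8)
n = min(both.sum(), liver_only.sum())  # capped by the rarer group (here 4)
tumor_slices = rng.choice(np.where(both)[0], size=n, replace=False)
liver_slices = rng.choice(np.where(liver_only)[0], size=n, replace=False)
print(sorted(tumor_slices), sorted(liver_slices))  # n slices from each group
```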
eva/vision/data/datasets/segmentation/monusac.py CHANGED
@@ -10,12 +10,12 @@ import imagesize
 import numpy as np
 import numpy.typing as npt
 import torch
-import tqdm
 from skimage import draw
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
+from eva.core.utils.progress_bar import tqdm
 from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -84,7 +84,7 @@ class MoNuSAC(base.ImageSegmentation):
     @property
     @override
     def classes(self) -> List[str]:
-        return ["Epithelial", "Lymphocyte", "Neutrophil", "Macrophage"]
+        return ["Background", "Epithelial", "Lymphocyte", "Neutrophil", "Macrophage", "Ambiguous"]
 
     @functools.cached_property
     @override
@@ -107,8 +107,8 @@ class MoNuSAC(base.ImageSegmentation):
         _validators.check_dataset_integrity(
             self,
             length=self._expected_dataset_lengths.get(self._split, 0),
-            n_classes=4,
-            first_and_last_labels=("Epithelial", "Macrophage"),
+            n_classes=6,
+            first_and_last_labels=("Background", "Ambiguous"),
         )
 
     @override
@@ -161,7 +161,7 @@ class MoNuSAC(base.ImageSegmentation):
             for index, filename in enumerate(self._image_files)
         ]
        to_export = filter(lambda x: not os.path.isfile(x[1]), mask_files)
-        for sample_index, filename in tqdm.tqdm(
+        for sample_index, filename in tqdm(
             list(to_export),
             desc=">> Exporting semantic masks",
             leave=False,
@@ -199,9 +199,9 @@ class MoNuSAC(base.ImageSegmentation):
         semantic_labels = np.zeros((height, width), "uint8")  # type: ignore[reportCallIssue]
         for level in range(len(root)):
             label = [item.attrib["Name"] for item in root[level][0]][0]
-            class_id = self.class_to_idx.get(label, 254) + 1
+            class_id = self.class_to_idx.get(label, self.class_to_idx["Ambiguous"])
             # for the test dataset an additional class 'Ambiguous' was added for
-            # difficult regions with fuzzy boundaries - we return it as 255
+            # difficult regions with fuzzy boundaries
             regions = [item for child in root[level] for item in child if item.tag == "Region"]
             for region in regions:
                 vertices = np.array(
eva/vision/data/datasets/segmentation/total_segmentator_2d.py CHANGED
@@ -8,11 +8,11 @@ from typing import Any, Callable, Dict, List, Literal, Tuple
 import numpy as np
 import numpy.typing as npt
 import torch
-import tqdm
 from torchvision import tv_tensors
 from torchvision.datasets import utils
 from typing_extensions import override
 
+from eva.core.utils.progress_bar import tqdm
 from eva.vision.data.datasets import _validators, structs
 from eva.vision.data.datasets.segmentation import base
 from eva.vision.utils import io
@@ -224,7 +224,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
         ]
         to_export = filter(lambda x: not os.path.isfile(x[1]), semantic_labels)
 
-        for sample_index, filename in tqdm.tqdm(
+        for sample_index, filename in tqdm(
             list(to_export),
             desc=">> Exporting optimized semantic masks",
             leave=False,
eva/vision/data/datasets/wsi.py CHANGED
@@ -2,8 +2,9 @@
 
 import bisect
 import os
-from typing import Callable, List
+from typing import Any, Callable, Dict, List
 
+import pandas as pd
 from loguru import logger
 from torch.utils.data import dataset as torch_datasets
 from torchvision import tv_tensors
@@ -85,6 +86,17 @@ class WsiDataset(vision.VisionDataset):
         patch = self._apply_transforms(patch)
         return patch
 
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        """Loads the metadata for the patch at the specified index."""
+        x, y = self._coords.x_y[index]
+        return {
+            "x": x,
+            "y": y,
+            "width": self._coords.width,
+            "height": self._coords.height,
+            "level_idx": self._coords.level_idx,
+        }
+
     def _apply_transforms(self, image: tv_tensors.Image) -> tv_tensors.Image:
         if self._image_transforms is not None:
             image = self._image_transforms(image)
@@ -105,6 +117,7 @@ class MultiWsiDataset(vision.VisionDataset):
         overwrite_mpp: float | None = None,
         backend: str = "openslide",
         image_transforms: Callable | None = None,
+        coords_path: str | None = None,
     ):
         """Initializes a new dataset instance.
 
@@ -118,6 +131,7 @@ class MultiWsiDataset(vision.VisionDataset):
             sampler: The sampler to use for sampling patch coordinates.
             backend: The backend to use for reading the whole-slide images.
             image_transforms: Transforms to apply to the extracted image patches.
+            coords_path: File path to save the patch coordinates as .csv.
         """
         super().__init__()
 
@@ -130,6 +144,7 @@ class MultiWsiDataset(vision.VisionDataset):
         self._sampler = sampler
         self._backend = backend
         self._image_transforms = image_transforms
+        self._coords_path = coords_path
 
         self._concat_dataset: torch_datasets.ConcatDataset
 
@@ -146,6 +161,7 @@ class MultiWsiDataset(vision.VisionDataset):
     @override
     def configure(self) -> None:
         self._concat_dataset = torch_datasets.ConcatDataset(datasets=self._load_datasets())
+        self._save_coords_to_file()
 
     @override
     def __len__(self) -> int:
@@ -159,6 +175,12 @@ class MultiWsiDataset(vision.VisionDataset):
     def filename(self, index: int) -> str:
         return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
 
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        """Loads the metadata for the patch at the specified index."""
+        dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index)
+        patch_metadata = self.datasets[dataset_index].load_metadata(sample_index)
+        return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata
+
     def _load_datasets(self) -> list[WsiDataset]:
         logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...")
         wsi_datasets = []
@@ -185,3 +207,17 @@ class MultiWsiDataset(vision.VisionDataset):
 
     def _get_dataset_idx(self, index: int) -> int:
         return bisect.bisect_right(self.cumulative_sizes, index)
+
+    def _get_sample_idx(self, index: int) -> int:
+        dataset_idx = self._get_dataset_idx(index)
+        return index if dataset_idx == 0 else index - self.cumulative_sizes[dataset_idx - 1]
+
+    def _save_coords_to_file(self):
+        if self._coords_path is not None:
+            coords = [
+                {"file": self._file_paths[i]} | dataset._coords.to_dict()
+                for i, dataset in enumerate(self.datasets)
+            ]
+            os.makedirs(os.path.abspath(os.path.join(self._coords_path, os.pardir)), exist_ok=True)
+            pd.DataFrame(coords).to_csv(self._coords_path, index=False)
+            logger.info(f"Saved patch coordinates to: {self._coords_path}")
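
Together, `coords_path` and the new `load_metadata` hooks make the sampled patch grid observable from outside the dataset. A hedged construction sketch: the paths are placeholders, and every constructor argument other than those named in this diff (`sampler`, `backend`, `image_transforms`, `coords_path`) is an assumption about the unchanged parts of the signature:

```python
from eva.vision.data.datasets.wsi import MultiWsiDataset
from eva.vision.data.wsi.patching.samplers.random import RandomSampler

dataset = MultiWsiDataset(
    root="./slides",                       # placeholder directory
    file_paths=["a.svs", "b.svs"],         # placeholder WSIs under root
    width=224,
    height=224,
    target_mpp=0.5,
    sampler=RandomSampler(n_samples=100, seed=8),
    coords_path="./outputs/coords.csv",    # new in 0.1.3: writes one row per WSI
)
dataset.configure()               # builds the per-WSI datasets, then writes the CSV
print(dataset.load_metadata(0))   # {'wsi_id': 'a', 'x': ..., 'y': ..., 'width': 224, ...}
```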
eva/vision/data/wsi/patching/coordinates.py CHANGED
@@ -2,7 +2,7 @@
 
 import dataclasses
 import functools
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from eva.vision.data.wsi import backends
 from eva.vision.data.wsi.patching import samplers
@@ -75,6 +75,14 @@ class PatchCoordinates:
 
         return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))
 
+    def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
+        """Convert the coordinates to a dictionary."""
+        include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
+        coord_dict = dataclasses.asdict(self)
+        if include_keys:
+            coord_dict = {key: coord_dict[key] for key in include_keys}
+        return coord_dict
+
 
 @functools.lru_cache(LRU_CACHE_SIZE)
 def get_cached_coords(
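
This `to_dict` is what feeds the CSV rows written by `_save_coords_to_file` above. A self-contained illustration on a minimal stand-in class (the real `PatchCoordinates` has more context; the field values here are made up):

```python
import dataclasses
from typing import Any, Dict, List, Tuple

@dataclasses.dataclass
class PatchCoordinates:
    """Minimal stand-in mirroring only the fields that to_dict() serializes."""
    x_y: List[Tuple[int, int]]
    width: int
    height: int
    level_idx: int
    mask: Any = None  # excluded by the default include_keys

    def to_dict(self, include_keys: List[str] | None = None) -> Dict[str, Any]:
        include_keys = include_keys or ["x_y", "width", "height", "level_idx"]
        coord_dict = dataclasses.asdict(self)
        return {key: coord_dict[key] for key in include_keys}

coords = PatchCoordinates(x_y=[(0, 0), (224, 0)], width=224, height=224, level_idx=0)
print({"file": "slide_001.svs"} | coords.to_dict())  # the shape of one CSV row
```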
eva/vision/data/wsi/patching/samplers/_utils.py CHANGED
@@ -1,14 +1,8 @@
-import random
 from typing import Tuple
 
 import numpy as np
 
 
-def set_seed(seed: int) -> None:
-    random.seed(seed)
-    np.random.seed(seed)
-
-
 def get_grid_coords_and_indices(
     layer_shape: Tuple[int, int],
     width: int,
@@ -33,8 +27,8 @@ def get_grid_coords_and_indices(
 
     indices = list(range(len(x_y)))
     if shuffle:
-        set_seed(seed)
-        np.random.shuffle(indices)
+        random_generator = np.random.default_rng(seed)
+        random_generator.shuffle(indices)
     return x_y, indices
 
 
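The move from `np.random.seed` to `np.random.default_rng` swaps mutation of NumPy's process-wide legacy RNG for an isolated `Generator`, so shuffling patch indices can no longer perturb, or be perturbed by, other seeded code. A minimal illustration:

```python
import numpy as np

indices = list(range(10))

# Old pattern: reseeds NumPy's global legacy RNG as a side effect.
np.random.seed(42)
np.random.shuffle(indices)

# New pattern: a local generator; identical call sites cannot interfere.
rng = np.random.default_rng(42)
rng.shuffle(indices)
print(indices)
```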
 
eva/vision/data/wsi/patching/samplers/random.py CHANGED
@@ -18,6 +18,7 @@ class RandomSampler(base.Sampler):
         """Initializes the sampler."""
         self.seed = seed
         self.n_samples = n_samples
+        self.random_generator = random.Random(seed)  # nosec
 
     def sample(
         self,
@@ -33,9 +34,10 @@ class RandomSampler(base.Sampler):
             layer_shape: The shape of the layer.
         """
         _utils.validate_dimensions(width, height, layer_shape)
-        _utils.set_seed(self.seed)
 
         x_max, y_max = layer_shape[0], layer_shape[1]
         for _ in range(self.n_samples):
-            x, y = random.randint(0, x_max - width), random.randint(0, y_max - height)  # nosec
+            x, y = self.random_generator.randint(0, x_max - width), self.random_generator.randint(
+                0, y_max - height
+            )
             yield x, y
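
Holding a `random.Random(seed)` per sampler, instead of reseeding the `random` module before every `sample()` call, keeps coordinates reproducible without global side effects, and means consecutive `sample()` calls continue one stream rather than replaying the same patches. A small stdlib sketch of the difference:

```python
import random

# Reseeding the module (the removed set_seed pattern) replays the same values:
random.seed(8)
first = [random.randint(0, 100) for _ in range(3)]
random.seed(8)
repeat = [random.randint(0, 100) for _ in range(3)]
assert first == repeat  # every reseeded call restarted the stream

# A per-instance generator is seeded once and then continues its stream:
rng = random.Random(8)
print([rng.randint(0, 100) for _ in range(3)])  # same values as `first`
print([rng.randint(0, 100) for _ in range(3)])  # continues, does not repeat
```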
eva/vision/losses/__init__.py CHANGED
@@ -1,5 +1,5 @@
 """Loss functions API."""
 
-from eva.vision.losses.dice import DiceLoss
+from eva.vision.losses.dice import DiceCELoss, DiceLoss
 
-__all__ = ["DiceLoss"]
+__all__ = ["DiceLoss", "DiceCELoss"]
eva/vision/losses/dice.py CHANGED
@@ -1,4 +1,6 @@
-"""Dice loss."""
+"""Dice based loss functions."""
+
+from typing import Sequence, Tuple
 
 import torch
 from monai import losses
@@ -12,29 +14,94 @@ class DiceLoss(losses.DiceLoss):  # type: ignore
     Extends the implementation from MONAI
     - to support semantic target labels (meaning targets of shape BHW)
     - to support `ignore_index` functionality
+    - accept weight argument in list format
     """
 
-    def __init__(self, *args, ignore_index: int | None = None, **kwargs) -> None:
-        """Initialize the DiceLoss with support for ignore_index.
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceLoss.
 
         Args:
             args: Positional arguments from the base class.
             ignore_index: Specifies a target value that is ignored and
                 does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
             kwargs: Key-word arguments from the base class.
         """
-        super().__init__(*args, **kwargs)
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
 
         self.ignore_index = ignore_index
 
     @override
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
-        if self.ignore_index is not None:
-            mask = targets != self.ignore_index
-            targets = targets * mask
-            inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
 
         if targets.ndim == 3:
             targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1])
 
         return super().forward(inputs, targets)
+
+
+class DiceCELoss(losses.dice.DiceCELoss):
+    """Combination of Dice and Cross Entropy Loss.
+
+    Extends the implementation from MONAI
+    - to support semantic target labels (meaning targets of shape BHW)
+    - to support `ignore_index` functionality
+    - accept weight argument in list format
+    """
+
+    def __init__(
+        self,
+        *args,
+        ignore_index: int | None = None,
+        weight: Sequence[float] | torch.Tensor | None = None,
+        **kwargs,
+    ) -> None:
+        """Initialize the DiceCELoss.
+
+        Args:
+            args: Positional arguments from the base class.
+            ignore_index: Specifies a target value that is ignored and
+                does not contribute to the input gradient.
+            weight: A list of weights to assign to each class.
+            kwargs: Key-word arguments from the base class.
+        """
+        if weight is not None and not isinstance(weight, torch.Tensor):
+            weight = torch.tensor(weight)
+
+        super().__init__(*args, **kwargs, weight=weight)
+
+        self.ignore_index = ignore_index
+
+    @override
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:  # noqa
+        inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index)
+        targets = _to_one_hot(targets, num_classes=inputs.shape[1])
+
+        return super().forward(inputs, targets)
+
+
+def _apply_ignore_index(
+    inputs: torch.Tensor, targets: torch.Tensor, ignore_index: int | None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if ignore_index is not None:
+        mask = targets != ignore_index
+        targets = targets * mask
+        inputs = torch.mul(inputs, mask.unsqueeze(1) if mask.ndim == 3 else mask)
+    return inputs, targets
+
+
+def _to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    if tensor.ndim == 3:
+        return one_hot(tensor[:, None, ...], num_classes=num_classes)
+    return tensor
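
A hedged usage sketch for the new `DiceCELoss`: semantic targets of shape BHW plus a class-weight list, per the extensions listed in its docstring. The `softmax` flag is a standard MONAI base-class kwarg; the specific values here are illustrative, not from the diff:

```python
import torch

from eva.vision.losses import DiceCELoss

criterion = DiceCELoss(
    softmax=True,            # MONAI base-class kwarg: normalize logits over classes
    ignore_index=255,        # pixels labeled 255 contribute no gradient
    weight=[0.2, 1.0, 1.0],  # plain list, converted to a tensor in __init__
)

logits = torch.randn(2, 3, 64, 64)          # B, C, H, W
targets = torch.randint(0, 3, (2, 64, 64))  # semantic labels of shape BHW
print(criterion(logits, targets))
```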
eva/vision/metrics/__init__.py ADDED
@@ -0,0 +1,11 @@
+"""Default metric collections API."""
+
+from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
+from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+
+__all__ = [
+    "MulticlassSegmentationMetrics",
+    "GeneralizedDiceScore",
+    "MeanIoU",
+]
eva/vision/metrics/defaults/__init__.py ADDED
@@ -0,0 +1,7 @@
+"""Default metric collections API."""
+
+from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+
+__all__ = [
+    "MulticlassSegmentationMetrics",
+]
eva/{core → vision}/metrics/defaults/segmentation/__init__.py RENAMED
@@ -1,5 +1,5 @@
 """Default segmentation metric collections API."""
 
-from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
+from eva.vision.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
 
 __all__ = ["MulticlassSegmentationMetrics"]
eva/{core → vision}/metrics/defaults/segmentation/multiclass.py RENAMED
@@ -1,6 +1,7 @@
 """Default metric collection for multiclass semantic segmentation tasks."""
 
-from eva.core.metrics import generalized_dice, mean_iou, structs
+from eva.core.metrics import structs
+from eva.vision.metrics.segmentation import generalized_dice, mean_iou
 
 
 class MulticlassSegmentationMetrics(structs.MetricCollection):
eva/vision/metrics/segmentation/BUILD ADDED
@@ -0,0 +1 @@
+python_sources()
eva/vision/metrics/segmentation/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""Segmentation metrics API."""
+
+from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
+from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+
+__all__ = [
+    "GeneralizedDiceScore",
+    "MeanIoU",
+]
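
For code that imported these metrics from `eva.core`, files 30–38 in the list above amount to an import-path change. A migration sketch based on the exports in the new `__init__` files:

```python
# 0.1.0 paths (modules moved or removed in this release):
# from eva.core.metrics import generalized_dice, mean_iou
# from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics

# 0.1.3 paths, per the new eva/vision/metrics __init__ files:
from eva.vision.metrics import GeneralizedDiceScore, MeanIoU
from eva.vision.metrics.defaults import MulticlassSegmentationMetrics
```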