kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/data/datasets/classification/embeddings.py +2 -2
  3. eva/core/data/datasets/classification/multi_embeddings.py +2 -2
  4. eva/core/data/datasets/embeddings.py +4 -4
  5. eva/core/data/samplers/classification/balanced.py +19 -18
  6. eva/core/loggers/utils/wandb.py +33 -0
  7. eva/core/models/modules/head.py +5 -3
  8. eva/core/models/modules/typings.py +2 -2
  9. eva/core/models/transforms/__init__.py +2 -1
  10. eva/core/models/transforms/as_discrete.py +57 -0
  11. eva/core/models/wrappers/_utils.py +121 -1
  12. eva/core/trainers/functional.py +8 -5
  13. eva/core/trainers/trainer.py +32 -17
  14. eva/core/utils/suppress_logs.py +28 -0
  15. eva/vision/data/__init__.py +2 -2
  16. eva/vision/data/dataloaders/__init__.py +5 -0
  17. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  18. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  19. eva/vision/data/datasets/__init__.py +10 -2
  20. eva/vision/data/datasets/classification/__init__.py +9 -0
  21. eva/vision/data/datasets/classification/bach.py +3 -4
  22. eva/vision/data/datasets/classification/bracs.py +111 -0
  23. eva/vision/data/datasets/classification/breakhis.py +209 -0
  24. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  25. eva/vision/data/datasets/classification/crc.py +3 -4
  26. eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
  27. eva/vision/data/datasets/classification/mhist.py +3 -4
  28. eva/vision/data/datasets/classification/panda.py +4 -5
  29. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  30. eva/vision/data/datasets/classification/unitopatho.py +158 -0
  31. eva/vision/data/datasets/classification/wsi.py +6 -5
  32. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  33. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  34. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  35. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  36. eva/vision/data/datasets/segmentation/consep.py +6 -7
  37. eva/vision/data/datasets/segmentation/embeddings.py +2 -2
  38. eva/vision/data/datasets/segmentation/lits.py +9 -8
  39. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  40. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  41. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  42. eva/vision/data/datasets/vision.py +95 -4
  43. eva/vision/data/datasets/wsi.py +5 -5
  44. eva/vision/data/transforms/__init__.py +22 -3
  45. eva/vision/data/transforms/common/__init__.py +1 -2
  46. eva/vision/data/transforms/croppad/__init__.py +11 -0
  47. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  48. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  49. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  50. eva/vision/data/transforms/intensity/__init__.py +11 -0
  51. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  52. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  53. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  54. eva/vision/data/transforms/spatial/__init__.py +7 -0
  55. eva/vision/data/transforms/spatial/flip.py +72 -0
  56. eva/vision/data/transforms/spatial/rotate.py +53 -0
  57. eva/vision/data/transforms/spatial/spacing.py +69 -0
  58. eva/vision/data/transforms/utility/__init__.py +5 -0
  59. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  60. eva/vision/data/tv_tensors/__init__.py +5 -0
  61. eva/vision/data/tv_tensors/volume.py +61 -0
  62. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  63. eva/vision/models/modules/semantic_segmentation.py +28 -20
  64. eva/vision/models/networks/backbones/__init__.py +9 -2
  65. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  66. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  67. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  68. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  69. eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
  70. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  72. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  73. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  74. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  75. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  76. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  77. eva/vision/utils/io/__init__.py +2 -0
  78. eva/vision/utils/io/nifti.py +91 -11
  79. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
  80. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
  81. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
  82. eva/vision/data/datasets/classification/base.py +0 -96
  83. eva/vision/data/datasets/segmentation/base.py +0 -96
  84. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  85. eva/vision/data/transforms/normalization/__init__.py +0 -6
  86. eva/vision/data/transforms/normalization/clamp.py +0 -43
  87. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  88. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  89. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  90. eva/vision/metrics/segmentation/BUILD +0 -1
  91. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  92. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  93. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
  94. {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/monusac.py
@@ -16,12 +16,11 @@ from torchvision.datasets import utils
  from typing_extensions import override

  from eva.core.utils.progress_bar import tqdm
- from eva.vision.data.datasets import _validators, structs
- from eva.vision.data.datasets.segmentation import base
+ from eva.vision.data.datasets import _validators, structs, vision
  from eva.vision.utils import io


- class MoNuSAC(base.ImageSegmentation):
+ class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
      """MoNuSAC2020: A Multi-organ Nuclei Segmentation and Classification Challenge.

      Webpage: https://monusac-2020.grand-challenge.org/
@@ -112,13 +111,13 @@ class MoNuSAC(base.ImageSegmentation):
          )

      @override
-     def load_image(self, index: int) -> tv_tensors.Image:
+     def load_data(self, index: int) -> tv_tensors.Image:
          image_path = self._image_files[index]
          image_rgb_array = io.read_image(image_path)
          return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))

      @override
-     def load_mask(self, index: int) -> tv_tensors.Mask:
+     def load_target(self, index: int) -> tv_tensors.Mask:
          semantic_labels = (
              self._load_semantic_mask_file(index)
              if self._export_masks
eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -17,12 +17,12 @@ from typing_extensions import override

  from eva.core.utils import io as core_io
  from eva.core.utils import multiprocessing
- from eva.vision.data.datasets import _validators, structs
- from eva.vision.data.datasets.segmentation import _total_segmentator, base
+ from eva.vision.data.datasets import _validators, structs, vision
+ from eva.vision.data.datasets.segmentation import _total_segmentator
  from eva.vision.utils import io


- class TotalSegmentator2D(base.ImageSegmentation):
+ class TotalSegmentator2D(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
      """TotalSegmentator 2D segmentation dataset."""

      _expected_dataset_lengths: Dict[str, int] = {
@@ -206,19 +206,20 @@ class TotalSegmentator2D(base.ImageSegmentation):
          return len(self._indices)

      @override
-     def load_image(self, index: int) -> tv_tensors.Image:
+     def load_data(self, index: int) -> tv_tensors.Image:
          sample_index, slice_index = self._indices[index]
          image_path = self._get_image_path(sample_index)
-         image_array = io.read_nifti(image_path, slice_index)
+         image_nii = io.read_nifti(image_path, slice_index)
+         image_array = io.nifti_to_array(image_nii)
          image_array = self._fix_orientation(image_array)
          return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))

      @override
-     def load_mask(self, index: int) -> tv_tensors.Mask:
+     def load_target(self, index: int) -> tv_tensors.Mask:
          if self._optimize_mask_loading:
              mask = self._load_semantic_label_mask(index)
          else:
-             mask = self._load_mask(index)
+             mask = self._load_target(index)
          mask = self._fix_orientation(mask)
          return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64)  # type: ignore

@@ -227,14 +228,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
          _, slice_index = self._indices[index]
          return {"slice_index": slice_index}

-     def _load_mask(self, index: int) -> npt.NDArray[Any]:
+     def _load_target(self, index: int) -> npt.NDArray[Any]:
          sample_index, slice_index = self._indices[index]
          return self._load_masks_as_semantic_label(sample_index, slice_index)

      def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
          """Loads the segmentation mask from a semantic label NifTi file."""
          sample_index, slice_index = self._indices[index]
-         return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
+         nii = io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
+         return io.nifti_to_array(nii)

      def _load_masks_as_semantic_label(
          self, sample_index: int, slice_index: int | None = None
@@ -248,7 +250,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
          masks_dir = self._get_masks_dir(sample_index)
          classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
          mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
-         binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
+         binary_masks = [io.nifti_to_array(io.read_nifti(path, slice_index)) for path in mask_paths]

          if self._class_mappings:
              mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
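Note on the hunks above: `io.read_nifti` now returns a NIfTI image object rather than a raw array, so callers convert explicitly with `io.nifti_to_array`. A minimal sketch of the new two-step pattern (variable names are illustrative):

    nii = io.read_nifti(image_path, slice_index)    # NIfTI image object
    image_array = io.nifti_to_array(nii)            # plain numpy array for further processing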
eva/vision/data/datasets/vision.py
@@ -1,17 +1,92 @@
  """Vision Dataset base class."""

  import abc
- from typing import Generic, TypeVar
+ from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar

  from eva.core.data.datasets import base

- DataSample = TypeVar("DataSample")
- """The data sample type."""
+ InputType = TypeVar("InputType")
+ """The input data type."""

+ TargetType = TypeVar("TargetType")
+ """The target data type."""

- class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
+
+ class VisionDataset(
+     base.MapDataset[Tuple[InputType, TargetType, Dict[str, Any]]],
+     abc.ABC,
+     Generic[InputType, TargetType],
+ ):
      """Base dataset class for vision tasks."""

+     def __init__(
+         self,
+         transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the dataset.
+
+         Args:
+             transforms: A function/transform which returns a transformed
+                 version of the raw data samples.
+         """
+         super().__init__()
+
+         self._transforms = transforms
+
+     @property
+     def classes(self) -> List[str] | None:
+         """Returns the list with names of the dataset names."""
+
+     @property
+     def class_to_idx(self) -> Dict[str, int] | None:
+         """Returns a mapping of the class name to its target index."""
+
+     def __getitem__(self, index: int) -> Tuple[InputType, TargetType, Dict[str, Any]]:
+         """Returns the `index`'th data sample.
+
+         Args:
+             index: The index of the data sample to load.
+
+         Returns:
+             A tuple with the image, the target and the metadata.
+         """
+         image = self.load_data(index)
+         target = self.load_target(index)
+         image, target = self._apply_transforms(image, target)
+         return image, target, self.load_metadata(index) or {}
+
+     def load_metadata(self, index: int) -> Dict[str, Any] | None:
+         """Returns the dataset metadata.
+
+         Args:
+             index: The index of the data sample to return the metadata of.
+
+         Returns:
+             The sample metadata.
+         """
+
+     @abc.abstractmethod
+     def load_data(self, index: int) -> InputType:
+         """Returns the `index`'th data sample.
+
+         Args:
+             index: The index of the data sample to load.
+
+         Returns:
+             The sample data.
+         """
+
+     @abc.abstractmethod
+     def load_target(self, index: int) -> TargetType:
+         """Returns the `index`'th target sample.
+
+         Args:
+             index: The index of the data sample to load.
+
+         Returns:
+             The sample target.
+         """
+
      @abc.abstractmethod
      def filename(self, index: int) -> str:
          """Returns the filename of the `index`'th data sample.
@@ -24,3 +99,19 @@ class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
          Returns:
              The filename of the `index`'th data sample.
          """
+
+     def _apply_transforms(
+         self, image: InputType, target: TargetType
+     ) -> Tuple[InputType, TargetType]:
+         """Applies the transforms to the provided data and returns them.
+
+         Args:
+             image: The desired image.
+             target: The target of the image.
+
+         Returns:
+             A tuple with the image and the target transformed.
+         """
+         if self._transforms is not None:
+             image, target = self._transforms(image, target)
+         return image, target
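For orientation, here is a minimal sketch of a dataset written against the refactored generic `VisionDataset[InputType, TargetType]` API shown above; the dataset class and its data are hypothetical, but the overridden methods mirror how MoNuSAC and TotalSegmentator2D adapt to `load_data`/`load_target` in the other hunks:

    from typing import Callable, List

    import torch
    from torchvision import tv_tensors
    from typing_extensions import override

    from eva.vision.data.datasets import vision


    class ExampleSegmentationDataset(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
        """Hypothetical dataset yielding (image, mask, metadata) triplets."""

        def __init__(self, image_paths: List[str], transforms: Callable | None = None) -> None:
            super().__init__(transforms=transforms)
            self._image_paths = image_paths

        @override
        def load_data(self, index: int) -> tv_tensors.Image:
            # Placeholder: a real dataset would read the image at self._image_paths[index].
            return tv_tensors.Image(torch.rand(3, 224, 224))

        @override
        def load_target(self, index: int) -> tv_tensors.Mask:
            # Placeholder: a real dataset would load the corresponding segmentation mask.
            return tv_tensors.Mask(torch.zeros(224, 224, dtype=torch.int64))

        @override
        def filename(self, index: int) -> str:
            return self._image_paths[index]

        def __len__(self) -> int:
            return len(self._image_paths)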
eva/vision/data/datasets/wsi.py
@@ -11,12 +11,12 @@ from torchvision import tv_tensors
  from torchvision.transforms.v2 import functional
  from typing_extensions import override

+ from eva.core.data.datasets import base
  from eva.vision.data import wsi
- from eva.vision.data.datasets import vision
  from eva.vision.data.wsi.patching import samplers


- class WsiDataset(vision.VisionDataset):
+ class WsiDataset(base.MapDataset):
      """Dataset class for reading patches from whole-slide images."""

      def __init__(
@@ -57,8 +57,8 @@ class WsiDataset(vision.VisionDataset):
      def __len__(self):
          return len(self._coords.x_y)

-     @override
      def filename(self, index: int) -> str:
+         """Returns the filename of the patch at the specified index."""
          return f"{self._file_path}_{index}"

      @property
@@ -103,7 +103,7 @@ class WsiDataset(vision.VisionDataset):
          return image


- class MultiWsiDataset(vision.VisionDataset):
+ class MultiWsiDataset(base.MapDataset):
      """Dataset class for reading patches from multiple whole-slide images."""

      def __init__(
@@ -171,8 +171,8 @@ class MultiWsiDataset(vision.VisionDataset):
      def __getitem__(self, index: int) -> tv_tensors.Image:
          return self._concat_dataset[index]

-     @override
      def filename(self, index: int) -> str:
+         """Returns the filename of the patch at the specified index."""
          return os.path.basename(self._file_paths[self._get_dataset_idx(index)])

      def load_metadata(self, index: int) -> Dict[str, Any]:
eva/vision/data/transforms/__init__.py
@@ -1,6 +1,25 @@
  """Vision data transforms."""

- from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop
- from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity
+ from eva.vision.data.transforms.common import ResizeAndCrop
+ from eva.vision.data.transforms.croppad import CropForeground, RandCropByPosNegLabel, SpatialPad
+ from eva.vision.data.transforms.intensity import (
+     RandScaleIntensity,
+     RandShiftIntensity,
+     ScaleIntensityRange,
+ )
+ from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Spacing
+ from eva.vision.data.transforms.utility import EnsureChannelFirst

- __all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"]
+ __all__ = [
+     "ResizeAndCrop",
+     "CropForeground",
+     "RandCropByPosNegLabel",
+     "SpatialPad",
+     "RandScaleIntensity",
+     "RandShiftIntensity",
+     "ScaleIntensityRange",
+     "RandFlip",
+     "RandRotate90",
+     "Spacing",
+     "EnsureChannelFirst",
+ ]
eva/vision/data/transforms/common/__init__.py
@@ -1,6 +1,5 @@
  """Common vision transforms."""

- from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
  from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop

- __all__ = ["ResizeAndClamp", "ResizeAndCrop"]
+ __all__ = ["ResizeAndCrop"]
eva/vision/data/transforms/croppad/__init__.py (new file)
@@ -0,0 +1,11 @@
+ """Transforms for crop and pad operations."""
+
+ from eva.vision.data.transforms.croppad.crop_foreground import CropForeground
+ from eva.vision.data.transforms.croppad.rand_crop_by_pos_neg_label import RandCropByPosNegLabel
+ from eva.vision.data.transforms.croppad.spatial_pad import SpatialPad
+
+ __all__ = [
+     "CropForeground",
+     "RandCropByPosNegLabel",
+     "SpatialPad",
+ ]
eva/vision/data/transforms/croppad/crop_foreground.py (new file)
@@ -0,0 +1,110 @@
+ """Crop foreground transform."""
+
+ import functools
+ from typing import Any, Dict, List, Sequence
+
+ import torch
+ from monai.config import type_definitions
+ from monai.transforms.croppad import array as monai_croppad_transforms
+ from monai.utils.enums import PytorchPadMode
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class CropForeground(v2.Transform):
+     """Crop an image using a bounding box.
+
+     The bounding box is generated by selecting foreground using select_fn
+     at channels channel_indices. margin is added in each spatial dimension
+     of the bounding box. The typical usage is to help training and evaluation
+     if the valid part is small in the whole medical image.
+     """
+
+     def __init__(
+         self,
+         threshold: float = 0.0,
+         channel_indices: type_definitions.IndexSelection | None = None,
+         margin: Sequence[int] | int = 0,
+         allow_smaller: bool = True,
+         return_coords: bool = False,
+         k_divisible: Sequence[int] | int = 1,
+         mode: str = PytorchPadMode.CONSTANT,
+         **pad_kwargs,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             threshold: function to select expected foreground.
+             channel_indices: if defined, select foreground only on the specified channels
+                 of image. if None, select foreground on the whole image.
+             margin: add margin value to spatial dims of the bounding box, if only 1 value provided,
+                 use it for all dims.
+             allow_smaller: when computing box size with `margin`, whether to allow the image edges
+                 to be smaller than the final box edges. If `False`, part of a padded output box
+                 might be outside of the original image, if `True`, the image edges will be used as
+                 the box edges. Default to `True`.
+             return_coords: whether return the coordinates of spatial bounding box for foreground.
+             k_divisible: make each spatial dimension to be divisible by k, default to 1.
+                 if `k_divisible` is an int, the same `k` be applied to all the input spatial
+                 dimensions.
+             mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``,
+                 ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``,
+                 ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor:
+                 {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed
+                 string values or a user supplied function. Defaults to ``"constant"``.
+                 See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                 https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+             pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
+                 note that `np.pad` treats channel dimension as the first dimension.
+         """
+         super().__init__()
+
+         self._foreground_crop = monai_croppad_transforms.CropForeground(
+             select_fn=functools.partial(_threshold_fn, threshold=threshold),
+             channel_indices=channel_indices,
+             margin=margin,
+             allow_smaller=allow_smaller,
+             return_coords=return_coords,
+             k_divisible=k_divisible,
+             mode=mode,
+             lazy=False,
+             **pad_kwargs,
+         )
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
+         box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
+         return {"box_start": box_start, "box_end": box_end}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_foreground_cropped = self._foreground_crop.crop_pad(
+             inpt, params["box_start"], params["box_end"]
+         )
+         return tv_tensors.wrap(inpt_foreground_cropped, like=inpt)
+
+
+ def _threshold_fn(image: torch.Tensor, threshold: int | float = 0.3) -> torch.Tensor:
+     """Applies a thresholding operation to an image tensor.
+
+     Pixels greater than the threshold are set to True, while others are False.
+
+     Args:
+         image: Input image tensor with pixel values.
+         threshold: Threshold value.
+
+     Returns:
+         A binary mask tensor of the same shape as `image`,
+         where True represents pixels above the threshold.
+     """
+     return image > threshold
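A small usage sketch for the wrapper above (it assumes `eva_tv_tensors.Volume` wraps a channel-first tensor the way other torchvision tv_tensors do; shapes and the threshold value are illustrative):

    import torch
    from torchvision import tv_tensors

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import CropForeground

    # Synthetic channel-first volume with a bright cube in the centre, plus a matching mask.
    data = torch.zeros(1, 64, 64, 64)
    data[:, 16:48, 16:48, 16:48] = 1.0
    volume = eva_tv_tensors.Volume(data)
    mask = tv_tensors.Mask(torch.zeros(1, 64, 64, 64, dtype=torch.int64))

    transform = CropForeground(threshold=0.5, margin=2)
    cropped_volume, cropped_mask = transform(volume, mask)  # both cropped to the same bounding box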
eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py (new file)
@@ -0,0 +1,109 @@
+ """Crop foreground transform."""
+
+ import functools
+ from typing import Any, Dict, List, Sequence
+
+ import torch
+ from monai.config.type_definitions import NdarrayOrTensor
+ from monai.transforms.croppad import array as monai_croppad_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class RandCropByPosNegLabel(v2.Transform):
+     """Crop random fixed sized regions with the center being a foreground or background voxel.
+
+     Its based on the Pos Neg Ratio and will return a list of arrays for all the cropped images.
+     For example, crop two (3 x 3) arrays from (5 x 5) array with pos/neg=1::
+
+         [[[0, 0, 0, 0, 0],
+           [0, 1, 2, 1, 0],     [[0, 1, 2],     [[2, 1, 0],
+           [0, 1, 3, 0, 0], -->  [0, 1, 3],      [3, 0, 0],
+           [0, 0, 0, 0, 0],      [0, 0, 0]]      [0, 0, 0]]
+           [0, 0, 0, 0, 0]]]
+
+     If a dimension of the expected spatial size is larger than the input image size,
+     will not crop that dimension. So the cropped result may be smaller than expected
+     size, and the cropped results of several images may not have exactly same shape.
+     """
+
+     def __init__(
+         self,
+         spatial_size: Sequence[int] | int,
+         label: torch.Tensor | None = None,
+         pos: float = 1.0,
+         neg: float = 1.0,
+         num_samples: int = 1,
+         image: torch.Tensor | None = None,
+         image_threshold: float = 0.0,
+         fg_indices: NdarrayOrTensor | None = None,
+         bg_indices: NdarrayOrTensor | None = None,
+         allow_smaller: bool = False,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
+                 if a dimension of ROI size is larger than image size, will not crop that dimension.
+                 if components have non-positive values, corresponding size of `label` will be used.
+                 for example: if the spatial size of input data is [40, 40, 40] and
+                 `spatial_size=[32, 64, -1]`, the spatial size of output data will be [32, 40, 40].
+             label: the label image that is used for finding foreground/background, if None, must
+                 set at `self.__call__`. Non-zero indicates foreground, zero indicates background.
+             pos: used with `neg` together to calculate the ratio ``pos / (pos + neg)`` for
+                 the probability to pick a foreground voxel as center rather than background voxel.
+             neg: used with `pos` together to calculate the ratio ``pos / (pos + neg)`` for
+                 the probability to pick a foreground voxel as center rather than background voxel.
+             num_samples: number of samples (crop regions) to take in each list.
+             image: optional image data to help select valid area, can be same as `img` or another.
+                 if not None, use ``label == 0 & image > image_threshold`` to select the negative
+                 sample (background) center. Crop center will only come from valid image areas.
+             image_threshold: if enabled `image`, use ``image > image_threshold`` to determine
+                 the valid image content areas.
+             fg_indices: if provided pre-computed foreground indices of `label`, will ignore `image`
+                 and `image_threshold`, randomly select crop centers based on them, need to provide
+                 `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                 a typical usage is to call `FgBgToIndices` transform first and cache the results.
+             bg_indices: if provided pre-computed background indices of `label`, will ignore `image`
+                 and `image_threshold`, randomly select crop centers based on them, need to provide
+                 `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                 a typical usage is to call `FgBgToIndices` transform first and cache the results.
+             allow_smaller: if `False`, an exception will be raised if the image is smaller than
+                 the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
+                 match the cropped size (i.e., no cropping in that dimension).
+         """
+         super().__init__()
+
+         self._rand_crop = monai_croppad_transforms.RandCropByPosNegLabel(
+             spatial_size=spatial_size,
+             label=label,
+             pos=pos,
+             neg=neg,
+             num_samples=num_samples,
+             image=image,
+             image_threshold=image_threshold,
+             fg_indices=fg_indices,
+             bg_indices=bg_indices,
+             allow_smaller=allow_smaller,
+             lazy=False,
+         )
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
+         self._rand_crop.randomize(label=mask)
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
+         return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]
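A usage sketch under the same assumptions as above (shapes and ratios are illustrative); note that, unlike the other wrappers, this transform returns a list of `num_samples` crops per input:

    import torch
    from torchvision import tv_tensors

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import RandCropByPosNegLabel

    volume = eva_tv_tensors.Volume(torch.rand(1, 96, 96, 48))
    mask = tv_tensors.Mask((torch.rand(1, 96, 96, 48) > 0.95).long())  # sparse synthetic foreground

    transform = RandCropByPosNegLabel(spatial_size=(64, 64, 32), pos=2.0, neg=1.0, num_samples=4)
    volume_crops, mask_crops = transform(volume, mask)  # each is a list of 4 aligned crops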
eva/vision/data/transforms/croppad/spatial_pad.py (new file)
@@ -0,0 +1,67 @@
+ """General purpose cropper to produce sub-volume region of interest (ROI)."""
+
+ import functools
+ from typing import Any, Dict, Sequence
+
+ from monai.transforms.croppad import array as monai_croppad_transforms
+ from monai.utils.enums import Method, PytorchPadMode
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class SpatialPad(v2.Transform):
+     """Performs padding to the data.
+
+     Padding is applied symmetric for all sides or all on one side for each dimension.
+     """
+
+     def __init__(
+         self,
+         spatial_size: Sequence[int] | int | tuple[tuple[int, ...] | int, ...],
+         method: str = Method.SYMMETRIC,
+         mode: str = PytorchPadMode.CONSTANT,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             spatial_size: The spatial size of output data after padding.
+                 If a dimension of the input data size is larger than the
+                 pad size, will not pad that dimension. If its components
+                 have non-positive values, the corresponding size of input
+                 image will be used (no padding). for example: if the spatial
+                 size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`,
+                 the spatial size of output data will be [32, 30, 30].
+             method: {``"symmetric"``, ``"end"``}
+                 Pad image symmetrically on every side or only pad at the
+                 end sides. Defaults to ``"symmetric"``.
+             mode: available modes for numpy array:{``"constant"``, ``"edge"``,
+                 ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``,
+                 ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+                 available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``,
+                 ``"circular"``}. One of the listed string values or a user supplied function.
+                 Defaults to ``"constant"``.
+                 See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                 https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+         """
+         super().__init__()
+
+         self._spatial_pad = monai_croppad_transforms.SpatialPad(
+             spatial_size=spatial_size,
+             method=method,
+             mode=mode,
+         )
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_padded = self._spatial_pad(inpt)
+         return tv_tensors.wrap(inpt_padded, like=inpt)
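Because these wrappers subclass `torchvision.transforms.v2.Transform`, they compose with `v2.Compose`; a brief sketch under the same assumptions as above (sizes are illustrative):

    import torch
    from torchvision.transforms import v2

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import CropForeground, SpatialPad

    preprocess = v2.Compose([
        CropForeground(threshold=0.5),          # crop to the foreground bounding box
        SpatialPad(spatial_size=(96, 96, 96)),  # pad back to a fixed spatial size
    ])

    volume = eva_tv_tensors.Volume(torch.rand(1, 80, 80, 80))
    padded = preprocess(volume)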
eva/vision/data/transforms/intensity/__init__.py (new file)
@@ -0,0 +1,11 @@
+ """Transforms for intensity adjustment."""
+
+ from eva.vision.data.transforms.intensity.rand_scale_intensity import RandScaleIntensity
+ from eva.vision.data.transforms.intensity.rand_shift_intensity import RandShiftIntensity
+ from eva.vision.data.transforms.intensity.scale_intensity_ranged import ScaleIntensityRange
+
+ __all__ = [
+     "RandScaleIntensity",
+     "RandShiftIntensity",
+     "ScaleIntensityRange",
+ ]
eva/vision/data/transforms/intensity/rand_scale_intensity.py (new file)
@@ -0,0 +1,59 @@
+ """Intensity scaling transform."""
+
+ import functools
+ from typing import Any, Dict
+
+ import numpy as np
+ from monai.config.type_definitions import DtypeLike
+ from monai.transforms.intensity import array as monai_intensity_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class RandScaleIntensity(v2.Transform):
+     """Randomly scale the intensity of input image.
+
+     The factor is by ``v = v * (1 + factor)``, where
+     the `factor` is randomly picked.
+     """
+
+     def __init__(
+         self,
+         factors: tuple[float, float] | float,
+         prob: float = 0.1,
+         channel_wise: bool = False,
+         dtype: DtypeLike = np.float32,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             factors: factor range to randomly scale by ``v = v * (1 + factor)``.
+                 if single number, factor value is picked from (-factors, factors).
+             prob: probability of scale.
+             channel_wise: if True, shift intensity on each channel separately.
+                 For each channel, a random offset will be chosen. Please ensure
+                 that the first dimension represents the channel of the image if True.
+             dtype: output data type, if None, same as input image. defaults to float32.
+         """
+         super().__init__()
+
+         self._rand_scale_intensity = monai_intensity_transforms.RandScaleIntensity(
+             factors=factors,
+             prob=prob,
+             channel_wise=channel_wise,
+             dtype=dtype,
+         )
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+         inpt_scaled = self._rand_scale_intensity(inpt)
+         return tv_tensors.wrap(inpt_scaled, like=inpt)
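Finally, a brief sketch of the intensity wrapper on its own (the factor and probability values are illustrative):

    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms import RandScaleIntensity

    image = tv_tensors.Image(torch.rand(1, 224, 224))
    scale = RandScaleIntensity(factors=0.1, prob=1.0)  # always applies; scales by v * (1 + U(-0.1, 0.1))
    scaled = scale(image)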