kaiko-eva 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (85)
  1. eva/core/data/datasets/base.py +7 -2
  2. eva/core/models/modules/head.py +4 -2
  3. eva/core/models/modules/typings.py +2 -2
  4. eva/core/models/transforms/__init__.py +2 -1
  5. eva/core/models/transforms/as_discrete.py +57 -0
  6. eva/core/models/wrappers/_utils.py +121 -1
  7. eva/core/trainers/_recorder.py +4 -1
  8. eva/core/utils/suppress_logs.py +28 -0
  9. eva/vision/data/__init__.py +2 -2
  10. eva/vision/data/dataloaders/__init__.py +5 -0
  11. eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
  12. eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
  13. eva/vision/data/datasets/__init__.py +2 -2
  14. eva/vision/data/datasets/classification/bach.py +3 -4
  15. eva/vision/data/datasets/classification/bracs.py +3 -4
  16. eva/vision/data/datasets/classification/breakhis.py +3 -4
  17. eva/vision/data/datasets/classification/camelyon16.py +4 -5
  18. eva/vision/data/datasets/classification/crc.py +3 -4
  19. eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
  20. eva/vision/data/datasets/classification/mhist.py +3 -4
  21. eva/vision/data/datasets/classification/panda.py +4 -5
  22. eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
  23. eva/vision/data/datasets/classification/unitopatho.py +3 -4
  24. eva/vision/data/datasets/classification/wsi.py +6 -5
  25. eva/vision/data/datasets/segmentation/__init__.py +2 -2
  26. eva/vision/data/datasets/segmentation/_utils.py +47 -0
  27. eva/vision/data/datasets/segmentation/bcss.py +7 -8
  28. eva/vision/data/datasets/segmentation/btcv.py +236 -0
  29. eva/vision/data/datasets/segmentation/consep.py +6 -7
  30. eva/vision/data/datasets/segmentation/lits.py +9 -8
  31. eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
  32. eva/vision/data/datasets/segmentation/monusac.py +4 -5
  33. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
  34. eva/vision/data/datasets/vision.py +95 -4
  35. eva/vision/data/datasets/wsi.py +5 -5
  36. eva/vision/data/transforms/__init__.py +22 -3
  37. eva/vision/data/transforms/common/__init__.py +1 -2
  38. eva/vision/data/transforms/croppad/__init__.py +11 -0
  39. eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
  40. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
  41. eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
  42. eva/vision/data/transforms/intensity/__init__.py +11 -0
  43. eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
  44. eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
  45. eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
  46. eva/vision/data/transforms/spatial/__init__.py +7 -0
  47. eva/vision/data/transforms/spatial/flip.py +72 -0
  48. eva/vision/data/transforms/spatial/rotate.py +53 -0
  49. eva/vision/data/transforms/spatial/spacing.py +69 -0
  50. eva/vision/data/transforms/utility/__init__.py +5 -0
  51. eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
  52. eva/vision/data/tv_tensors/__init__.py +5 -0
  53. eva/vision/data/tv_tensors/volume.py +61 -0
  54. eva/vision/metrics/segmentation/monai_dice.py +9 -2
  55. eva/vision/models/modules/semantic_segmentation.py +32 -19
  56. eva/vision/models/networks/backbones/__init__.py +9 -2
  57. eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
  58. eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
  59. eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
  60. eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
  61. eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
  62. eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
  63. eva/vision/models/networks/backbones/radiology/voco.py +75 -0
  64. eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
  65. eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
  66. eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
  67. eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
  68. eva/vision/utils/io/__init__.py +2 -0
  69. eva/vision/utils/io/nifti.py +91 -11
  70. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/METADATA +16 -12
  71. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/RECORD +74 -58
  72. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/WHEEL +1 -1
  73. eva/vision/data/datasets/classification/base.py +0 -96
  74. eva/vision/data/datasets/segmentation/base.py +0 -96
  75. eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
  76. eva/vision/data/transforms/normalization/__init__.py +0 -6
  77. eva/vision/data/transforms/normalization/clamp.py +0 -43
  78. eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
  79. eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
  80. eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
  81. eva/vision/metrics/segmentation/BUILD +0 -1
  82. eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
  83. eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
  84. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/entry_points.txt +0 -0
  85. {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.2.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/wsi.py
@@ -11,12 +11,12 @@ from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional
 from typing_extensions import override

+from eva.core.data.datasets import base
 from eva.vision.data import wsi
-from eva.vision.data.datasets import vision
 from eva.vision.data.wsi.patching import samplers


-class WsiDataset(vision.VisionDataset):
+class WsiDataset(base.MapDataset):
     """Dataset class for reading patches from whole-slide images."""

     def __init__(
@@ -57,8 +57,8 @@ class WsiDataset(vision.VisionDataset):
     def __len__(self):
         return len(self._coords.x_y)

-    @override
     def filename(self, index: int) -> str:
+        """Returns the filename of the patch at the specified index."""
         return f"{self._file_path}_{index}"

     @property
@@ -103,7 +103,7 @@ class WsiDataset(vision.VisionDataset):
         return image


-class MultiWsiDataset(vision.VisionDataset):
+class MultiWsiDataset(base.MapDataset):
     """Dataset class for reading patches from multiple whole-slide images."""

     def __init__(
@@ -171,8 +171,8 @@ class MultiWsiDataset(vision.VisionDataset):
     def __getitem__(self, index: int) -> tv_tensors.Image:
         return self._concat_dataset[index]

-    @override
     def filename(self, index: int) -> str:
+        """Returns the filename of the patch at the specified index."""
         return os.path.basename(self._file_paths[self._get_dataset_idx(index)])

     def load_metadata(self, index: int) -> Dict[str, Any]:
eva/vision/data/transforms/__init__.py
@@ -1,6 +1,25 @@
 """Vision data transforms."""

-from eva.vision.data.transforms.common import ResizeAndClamp, ResizeAndCrop
-from eva.vision.data.transforms.normalization import Clamp, RescaleIntensity
+from eva.vision.data.transforms.common import ResizeAndCrop
+from eva.vision.data.transforms.croppad import CropForeground, RandCropByPosNegLabel, SpatialPad
+from eva.vision.data.transforms.intensity import (
+    RandScaleIntensity,
+    RandShiftIntensity,
+    ScaleIntensityRange,
+)
+from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Spacing
+from eva.vision.data.transforms.utility import EnsureChannelFirst

-__all__ = ["ResizeAndCrop", "ResizeAndClamp", "Clamp", "RescaleIntensity"]
+__all__ = [
+    "ResizeAndCrop",
+    "CropForeground",
+    "RandCropByPosNegLabel",
+    "SpatialPad",
+    "RandScaleIntensity",
+    "RandShiftIntensity",
+    "ScaleIntensityRange",
+    "RandFlip",
+    "RandRotate90",
+    "Spacing",
+    "EnsureChannelFirst",
+]
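The new public API re-exports MONAI-backed, torchvision-v2-compatible transforms from eva.vision.data.transforms. As a rough illustration of how they are meant to compose (a hedged sketch; the tensor shapes and parameter values below are illustrative assumptions, not values taken from the package's configs):

    # Illustrative only: composes some of the newly exported transforms on a toy
    # channel-first CT-like volume wrapped as a torchvision tv_tensor.
    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    from eva.vision.data.transforms import RandFlip, ScaleIntensityRange, SpatialPad

    pipeline = v2.Compose([
        ScaleIntensityRange(input_range=(-1024, 1024), output_range=(0.0, 1.0)),  # HU -> [0, 1]
        SpatialPad(spatial_size=[96, 96, 96]),       # pad spatial dims up to 96^3
        RandFlip(prob=0.5, spatial_axes=(0, 1, 2)),  # independent flip per spatial axis
    ])

    volume = tv_tensors.Image(torch.randn(1, 80, 80, 80))  # (C, D, H, W); the new Volume tv_tensor is also supported
    out = pipeline(volume)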
eva/vision/data/transforms/common/__init__.py
@@ -1,6 +1,5 @@
 """Common vision transforms."""

-from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
 from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop

-__all__ = ["ResizeAndClamp", "ResizeAndCrop"]
+__all__ = ["ResizeAndCrop"]
eva/vision/data/transforms/croppad/__init__.py (new file)
@@ -0,0 +1,11 @@
+"""Transforms for crop and pad operations."""
+
+from eva.vision.data.transforms.croppad.crop_foreground import CropForeground
+from eva.vision.data.transforms.croppad.rand_crop_by_pos_neg_label import RandCropByPosNegLabel
+from eva.vision.data.transforms.croppad.spatial_pad import SpatialPad
+
+__all__ = [
+    "CropForeground",
+    "RandCropByPosNegLabel",
+    "SpatialPad",
+]
eva/vision/data/transforms/croppad/crop_foreground.py (new file)
@@ -0,0 +1,110 @@
+"""Crop foreground transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.config import type_definitions
+from monai.transforms.croppad import array as monai_croppad_transforms
+from monai.utils.enums import PytorchPadMode
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class CropForeground(v2.Transform):
+    """Crop an image using a bounding box.
+
+    The bounding box is generated by selecting foreground using select_fn
+    at channels channel_indices. margin is added in each spatial dimension
+    of the bounding box. The typical usage is to help training and evaluation
+    if the valid part is small in the whole medical image.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 0.0,
+        channel_indices: type_definitions.IndexSelection | None = None,
+        margin: Sequence[int] | int = 0,
+        allow_smaller: bool = True,
+        return_coords: bool = False,
+        k_divisible: Sequence[int] | int = 1,
+        mode: str = PytorchPadMode.CONSTANT,
+        **pad_kwargs,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            threshold: function to select expected foreground.
+            channel_indices: if defined, select foreground only on the specified channels
+                of image. if None, select foreground on the whole image.
+            margin: add margin value to spatial dims of the bounding box, if only 1 value provided,
+                use it for all dims.
+            allow_smaller: when computing box size with `margin`, whether to allow the image edges
+                to be smaller than the final box edges. If `False`, part of a padded output box
+                might be outside of the original image, if `True`, the image edges will be used as
+                the box edges. Default to `True`.
+            return_coords: whether return the coordinates of spatial bounding box for foreground.
+            k_divisible: make each spatial dimension to be divisible by k, default to 1.
+                if `k_divisible` is an int, the same `k` be applied to all the input spatial
+                dimensions.
+            mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``,
+                ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``,
+                ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor:
+                {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed
+                string values or a user supplied function. Defaults to ``"constant"``.
+                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+            pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
+                note that `np.pad` treats channel dimension as the first dimension.
+        """
+        super().__init__()
+
+        self._foreground_crop = monai_croppad_transforms.CropForeground(
+            select_fn=functools.partial(_threshold_fn, threshold=threshold),
+            channel_indices=channel_indices,
+            margin=margin,
+            allow_smaller=allow_smaller,
+            return_coords=return_coords,
+            k_divisible=k_divisible,
+            mode=mode,
+            lazy=False,
+            **pad_kwargs,
+        )
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
+        box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
+        return {"box_start": box_start, "box_end": box_end}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_foreground_cropped = self._foreground_crop.crop_pad(
+            inpt, params["box_start"], params["box_end"]
+        )
+        return tv_tensors.wrap(inpt_foreground_cropped, like=inpt)
+
+
+def _threshold_fn(image: torch.Tensor, threshold: int | float = 0.3) -> torch.Tensor:
+    """Applies a thresholding operation to an image tensor.
+
+    Pixels greater than the threshold are set to True, while others are False.
+
+    Args:
+        image: Input image tensor with pixel values.
+        threshold: Threshold value.
+
+    Returns:
+        A binary mask tensor of the same shape as `image`,
+        where True represents pixels above the threshold.
+    """
+    return image > threshold
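CropForeground computes a single bounding box from the Volume input (in _get_params) and then applies the same crop to every registered input type, so image and mask stay aligned. A minimal hedged sketch of the intended call pattern; it assumes eva.vision.data.tv_tensors.Volume can wrap a plain tensor like other torchvision tv_tensors, which this diff does not show:

    # Hypothetical usage sketch; the Volume constructor is assumed, not shown in this diff.
    import torch
    from torchvision import tv_tensors

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import CropForeground

    crop = CropForeground(threshold=0.0, margin=2)

    data = torch.zeros(1, 64, 64, 64)
    data[:, 20:40, 20:40, 20:40] = 1.0                        # synthetic foreground block
    scan = eva_tv_tensors.Volume(data)                        # assumed constructor
    mask = tv_tensors.Mask(torch.zeros(1, 64, 64, 64, dtype=torch.long))

    # One bounding box is computed from the Volume and applied to both inputs.
    scan_cropped, mask_cropped = crop(scan, mask)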
eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py (new file)
@@ -0,0 +1,109 @@
+"""Crop foreground transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.transforms.croppad import array as monai_croppad_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandCropByPosNegLabel(v2.Transform):
+    """Crop random fixed sized regions with the center being a foreground or background voxel.
+
+    Its based on the Pos Neg Ratio and will return a list of arrays for all the cropped images.
+    For example, crop two (3 x 3) arrays from (5 x 5) array with pos/neg=1::
+
+        [[[0, 0, 0, 0, 0],
+          [0, 1, 2, 1, 0],          [[0, 1, 2],     [[2, 1, 0],
+          [0, 1, 3, 0, 0],    -->    [0, 1, 3],      [3, 0, 0],
+          [0, 0, 0, 0, 0],           [0, 0, 0]]      [0, 0, 0]]
+          [0, 0, 0, 0, 0]]]
+
+    If a dimension of the expected spatial size is larger than the input image size,
+    will not crop that dimension. So the cropped result may be smaller than expected
+    size, and the cropped results of several images may not have exactly same shape.
+    """
+
+    def __init__(
+        self,
+        spatial_size: Sequence[int] | int,
+        label: torch.Tensor | None = None,
+        pos: float = 1.0,
+        neg: float = 1.0,
+        num_samples: int = 1,
+        image: torch.Tensor | None = None,
+        image_threshold: float = 0.0,
+        fg_indices: NdarrayOrTensor | None = None,
+        bg_indices: NdarrayOrTensor | None = None,
+        allow_smaller: bool = False,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
+                if a dimension of ROI size is larger than image size, will not crop that dimension.
+                if components have non-positive values, corresponding size of `label` will be used.
+                for example: if the spatial size of input data is [40, 40, 40] and
+                `spatial_size=[32, 64, -1]`, the spatial size of output data will be [32, 40, 40].
+            label: the label image that is used for finding foreground/background, if None, must
+                set at `self.__call__`. Non-zero indicates foreground, zero indicates background.
+            pos: used with `neg` together to calculate the ratio ``pos / (pos + neg)`` for
+                the probability to pick a foreground voxel as center rather than background voxel.
+            neg: used with `pos` together to calculate the ratio ``pos / (pos + neg)`` for
+                the probability to pick a foreground voxel as center rather than background voxel.
+            num_samples: number of samples (crop regions) to take in each list.
+            image: optional image data to help select valid area, can be same as `img` or another.
+                if not None, use ``label == 0 & image > image_threshold`` to select the negative
+                sample (background) center. Crop center will only come from valid image areas.
+            image_threshold: if enabled `image`, use ``image > image_threshold`` to determine
+                the valid image content areas.
+            fg_indices: if provided pre-computed foreground indices of `label`, will ignore `image`
+                and `image_threshold`, randomly select crop centers based on them, need to provide
+                `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                a typical usage is to call `FgBgToIndices` transform first and cache the results.
+            bg_indices: if provided pre-computed background indices of `label`, will ignore `image`
+                and `image_threshold`, randomly select crop centers based on them, need to provide
+                `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                a typical usage is to call `FgBgToIndices` transform first and cache the results.
+            allow_smaller: if `False`, an exception will be raised if the image is smaller than
+                the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
+                match the cropped size (i.e., no cropping in that dimension).
+        """
+        super().__init__()
+
+        self._rand_crop = monai_croppad_transforms.RandCropByPosNegLabel(
+            spatial_size=spatial_size,
+            label=label,
+            pos=pos,
+            neg=neg,
+            num_samples=num_samples,
+            image=image,
+            image_threshold=image_threshold,
+            fg_indices=fg_indices,
+            bg_indices=bg_indices,
+            allow_smaller=allow_smaller,
+            lazy=False,
+        )
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
+        self._rand_crop.randomize(label=mask)
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
+        return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]
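The wrapper randomizes crop centers once per call from the Mask it finds among the inputs, then applies the same crops to every input and returns num_samples patches per input. A hedged sketch of that call pattern (shapes and ratios are illustrative, not values from the package):

    # Illustrative sketch: pos/neg-balanced patch sampling from an image/mask pair.
    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms import RandCropByPosNegLabel

    sampler = RandCropByPosNegLabel(spatial_size=[32, 32, 32], pos=1.0, neg=1.0, num_samples=4)

    image = tv_tensors.Image(torch.randn(1, 96, 96, 96))
    label = tv_tensors.Mask((torch.rand(1, 96, 96, 96) > 0.95).long())

    # Crop centers are drawn from the Mask; both outputs are lists of 4 aligned patches.
    image_patches, label_patches = sampler(image, label)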
eva/vision/data/transforms/croppad/spatial_pad.py (new file)
@@ -0,0 +1,67 @@
+"""General purpose cropper to produce sub-volume region of interest (ROI)."""
+
+import functools
+from typing import Any, Dict, Sequence
+
+from monai.transforms.croppad import array as monai_croppad_transforms
+from monai.utils.enums import Method, PytorchPadMode
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class SpatialPad(v2.Transform):
+    """Performs padding to the data.
+
+    Padding is applied symmetric for all sides or all on one side for each dimension.
+    """
+
+    def __init__(
+        self,
+        spatial_size: Sequence[int] | int | tuple[tuple[int, ...] | int, ...],
+        method: str = Method.SYMMETRIC,
+        mode: str = PytorchPadMode.CONSTANT,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            spatial_size: The spatial size of output data after padding.
+                If a dimension of the input data size is larger than the
+                pad size, will not pad that dimension. If its components
+                have non-positive values, the corresponding size of input
+                image will be used (no padding). for example: if the spatial
+                size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`,
+                the spatial size of output data will be [32, 30, 30].
+            method: {``"symmetric"``, ``"end"``}
+                Pad image symmetrically on every side or only pad at the
+                end sides. Defaults to ``"symmetric"``.
+            mode: available modes for numpy array:{``"constant"``, ``"edge"``,
+                ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``,
+                ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+                available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``,
+                ``"circular"``}. One of the listed string values or a user supplied function.
+                Defaults to ``"constant"``.
+                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+        """
+        super().__init__()
+
+        self._spatial_pad = monai_croppad_transforms.SpatialPad(
+            spatial_size=spatial_size,
+            method=method,
+            mode=mode,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_padded = self._spatial_pad(inpt)
+        return tv_tensors.wrap(inpt_padded, like=inpt)
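SpatialPad only pads dimensions that are smaller than the target size, which is why a non-positive entry (or an already large dimension) is left untouched. A small hedged example of the resulting shape, with illustrative sizes:

    # Illustrative sketch: pad a (C, D, H, W) tensor up to a target spatial grid.
    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms import SpatialPad

    pad = SpatialPad(spatial_size=[96, 96, -1])  # -1 leaves the last spatial dim unchanged

    image = tv_tensors.Image(torch.randn(1, 70, 96, 40))
    padded = pad(image)
    print(padded.shape)  # expected: torch.Size([1, 96, 96, 40])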
eva/vision/data/transforms/intensity/__init__.py (new file)
@@ -0,0 +1,11 @@
+"""Transforms for intensity adjustment."""
+
+from eva.vision.data.transforms.intensity.rand_scale_intensity import RandScaleIntensity
+from eva.vision.data.transforms.intensity.rand_shift_intensity import RandShiftIntensity
+from eva.vision.data.transforms.intensity.scale_intensity_ranged import ScaleIntensityRange
+
+__all__ = [
+    "RandScaleIntensity",
+    "RandShiftIntensity",
+    "ScaleIntensityRange",
+]
eva/vision/data/transforms/intensity/rand_scale_intensity.py (new file)
@@ -0,0 +1,59 @@
+"""Intensity scaling transform."""
+
+import functools
+from typing import Any, Dict
+
+import numpy as np
+from monai.config.type_definitions import DtypeLike
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandScaleIntensity(v2.Transform):
+    """Randomly scale the intensity of input image.
+
+    The factor is by ``v = v * (1 + factor)``, where
+    the `factor` is randomly picked.
+    """
+
+    def __init__(
+        self,
+        factors: tuple[float, float] | float,
+        prob: float = 0.1,
+        channel_wise: bool = False,
+        dtype: DtypeLike = np.float32,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            factors: factor range to randomly scale by ``v = v * (1 + factor)``.
+                if single number, factor value is picked from (-factors, factors).
+            prob: probability of scale.
+            channel_wise: if True, shift intensity on each channel separately.
+                For each channel, a random offset will be chosen. Please ensure
+                that the first dimension represents the channel of the image if True.
+            dtype: output data type, if None, same as input image. defaults to float32.
+        """
+        super().__init__()
+
+        self._rand_scale_intensity = monai_intensity_transforms.RandScaleIntensity(
+            factors=factors,
+            prob=prob,
+            channel_wise=channel_wise,
+            dtype=dtype,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._rand_scale_intensity(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
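So with factors=0.1, a factor is drawn uniformly from (-0.1, 0.1) and, with probability prob, every voxel is multiplied by (1 + factor). A brief hedged sketch with illustrative values:

    # Illustrative sketch: random multiplicative intensity jitter of up to roughly ±10%.
    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms import RandScaleIntensity

    jitter = RandScaleIntensity(factors=0.1, prob=1.0)  # prob=1.0 so the example always applies

    image = tv_tensors.Image(torch.ones(1, 16, 16, 16))
    scaled = jitter(image)  # all voxels become 1 + factor for one factor drawn from (-0.1, 0.1)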
eva/vision/data/transforms/intensity/rand_shift_intensity.py (new file)
@@ -0,0 +1,55 @@
+"""Intensity shifting transform."""
+
+import functools
+from typing import Any, Dict
+
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandShiftIntensity(v2.Transform):
+    """Randomly shift intensity with randomly picked offset."""
+
+    def __init__(
+        self,
+        offsets: tuple[float, float] | float,
+        safe: bool = False,
+        prob: float = 0.1,
+        channel_wise: bool = False,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            offsets: Offset range to randomly shift.
+                if single number, offset value is picked from (-offsets, offsets).
+            safe: If `True`, then do safe dtype convert when intensity overflow.
+                E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then
+                `[256, -12]` -> `[array(255), array(0)]`.
+            prob: Probability of shift.
+            channel_wise: If True, shift intensity on each channel separately.
+                For each channel, a random offset will be chosen. Please ensure
+                that the first dimension represents the channel of the image if True.
+        """
+        super().__init__()
+
+        self._rand_swift_intensity = monai_intensity_transforms.RandShiftIntensity(
+            offsets=offsets,
+            safe=safe,
+            prob=prob,
+            channel_wise=channel_wise,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._rand_swift_intensity(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
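RandShiftIntensity is the additive counterpart: an offset is drawn from (-offsets, offsets), or the given range, and added to the input, per channel when channel_wise=True. A short hedged sketch with illustrative values:

    # Illustrative sketch: random additive intensity shift of up to ±0.05.
    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms import RandShiftIntensity

    shift = RandShiftIntensity(offsets=0.05, prob=1.0)
    shifted = shift(tv_tensors.Image(torch.zeros(1, 16, 16, 16)))  # voxels become the drawn offset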
eva/vision/data/transforms/intensity/scale_intensity_ranged.py (new file)
@@ -0,0 +1,56 @@
+"""Intensity scaling transform."""
+
+import functools
+from typing import Any, Dict, Tuple
+
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class ScaleIntensityRange(v2.Transform):
+    """Intensity scaling transform.
+
+    Scaling from [a_min, a_max] to [b_min, b_max] with clip option.
+
+    When `b_min` or `b_max` are `None`, `scaled_array * (b_max - b_min) + b_min`
+    will be skipped. If `clip=True`, when `b_min`/`b_max` is None, the clipping
+    is not performed on the corresponding edge.
+    """
+
+    def __init__(
+        self,
+        input_range: Tuple[float, float],
+        output_range: Tuple[float, float] | None = None,
+        clip: bool = True,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            input_range: Intensity original range min and max.
+            output_range: Intensity target range min and max.
+            clip: Whether to perform clip after scaling.
+        """
+        super().__init__()
+
+        self._scale_intensity_range = monai_intensity_transforms.ScaleIntensityRange(
+            a_min=input_range[0],
+            a_max=input_range[1],
+            b_min=output_range[0] if output_range else None,
+            b_max=output_range[1] if output_range else None,
+            clip=clip,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._scale_intensity_range(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
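A typical use is CT windowing: mapping a Hounsfield-unit range to [0, 1] and clipping everything outside it. A hedged sketch (the HU window below is an illustrative choice, not a value from this package):

    # Illustrative sketch: soft-tissue-style CT windowing to [0, 1] with clipping.
    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms import ScaleIntensityRange

    window = ScaleIntensityRange(input_range=(-175.0, 250.0), output_range=(0.0, 1.0), clip=True)

    ct = tv_tensors.Image(torch.tensor([[[-1000.0, -175.0, 37.5, 250.0, 3000.0]]]))
    print(window(ct))  # approximately [0.0, 0.0, 0.5, 1.0, 1.0]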
eva/vision/data/transforms/spatial/__init__.py (new file)
@@ -0,0 +1,7 @@
+"""Transforms for spatial operations."""
+
+from eva.vision.data.transforms.spatial.flip import RandFlip
+from eva.vision.data.transforms.spatial.rotate import RandRotate90
+from eva.vision.data.transforms.spatial.spacing import Spacing
+
+__all__ = ["Spacing", "RandFlip", "RandRotate90"]
eva/vision/data/transforms/spatial/flip.py (new file)
@@ -0,0 +1,72 @@
+"""Flip transforms."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.transforms.spatial import array as monai_spatial_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandFlip(v2.Transform):
+    """Randomly flips the image along axes."""
+
+    def __init__(
+        self,
+        prob: float = 0.1,
+        spatial_axes: Sequence[int] | int | None = None,
+        apply_per_axis: bool = True,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            prob: Probability of flipping.
+            spatial_axes: Spatial axes along which to flip over. Default is None.
+            apply_per_axis: If True, will apply a random flip transform to each
+                axis individually (if spatial_axes is a sequence of multiple axis).
+                If False, will apply a single random flip transform applied to all axes.
+        """
+        super().__init__()
+
+        if apply_per_axis:
+            if not isinstance(spatial_axes, (list, tuple)):
+                raise ValueError(
+                    "`spatial_axis` is expected to be sequence `apply_per_axis` "
+                    f"is enabled, got {type(spatial_axes)}"
+                )
+            self._flips = [
+                monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=axis)
+                for axis in spatial_axes
+            ]
+        else:
+            self._flips = [monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=spatial_axes)]
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        for flip in self._flips:
+            flip.randomize(None)
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_flipped = self._apply_flips(inpt)
+        return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_flipped = torch.tensor(self._apply_flips(inpt), dtype=torch.long)
+        return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+    def _apply_flips(self, inpt: Any) -> Any:
+        for flip in self._flips:
+            inpt = flip(img=inpt, randomize=False)
+        return inpt
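With apply_per_axis=True (the default) the wrapper builds one MONAI RandFlip per axis, so each listed spatial axis is flipped independently with probability prob; note the constructor requires spatial_axes to be a sequence in that mode, and the Mask branch casts the result back to a long tensor. A hedged sketch with illustrative shapes:

    # Illustrative sketch: independent random flips along each spatial axis.
    import torch
    from torchvision import tv_tensors

    from eva.vision.data.transforms import RandFlip

    flip = RandFlip(prob=0.5, spatial_axes=(0, 1, 2))  # sequence required when apply_per_axis=True

    image = tv_tensors.Image(torch.randn(1, 32, 32, 32))
    mask = tv_tensors.Mask(torch.randint(0, 3, (1, 32, 32, 32)))

    # The same randomized flips are applied to both inputs; the mask stays integer-typed.
    image_flipped, mask_flipped = flip(image, mask)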