kaiko-eva 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/models/modules/head.py +4 -2
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +2 -2
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +3 -4
- eva/vision/data/datasets/classification/breakhis.py +3 -4
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +3 -4
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +3 -4
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +73 -57
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.0.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/wsi.py
CHANGED
@@ -11,12 +11,12 @@ from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional
 from typing_extensions import override

+from eva.core.data.datasets import base
 from eva.vision.data import wsi
-from eva.vision.data.datasets import vision
 from eva.vision.data.wsi.patching import samplers


-class WsiDataset(vision.VisionDataset):
+class WsiDataset(base.MapDataset):
     """Dataset class for reading patches from whole-slide images."""

     def __init__(
@@ -57,8 +57,8 @@ class WsiDataset(vision.VisionDataset):
     def __len__(self):
         return len(self._coords.x_y)

-    @override
     def filename(self, index: int) -> str:
+        """Returns the filename of the patch at the specified index."""
         return f"{self._file_path}_{index}"

     @property
@@ -103,7 +103,7 @@ class WsiDataset(vision.VisionDataset):
         return image


-class MultiWsiDataset(vision.VisionDataset):
+class MultiWsiDataset(base.MapDataset):
     """Dataset class for reading patches from multiple whole-slide images."""

     def __init__(
@@ -171,8 +171,8 @@ class MultiWsiDataset(vision.VisionDataset):
     def __getitem__(self, index: int) -> tv_tensors.Image:
         return self._concat_dataset[index]

-    @override
     def filename(self, index: int) -> str:
+        """Returns the filename of the patch at the specified index."""
         return os.path.basename(self._file_paths[self._get_dataset_idx(index)])

     def load_metadata(self, index: int) -> Dict[str, Any]:
eva/vision/data/transforms/__init__.py
CHANGED
@@ -1,6 +1,25 @@
 """Vision data transforms."""

-from eva.vision.data.transforms.common import
-from eva.vision.data.transforms.
+from eva.vision.data.transforms.common import ResizeAndCrop
+from eva.vision.data.transforms.croppad import CropForeground, RandCropByPosNegLabel, SpatialPad
+from eva.vision.data.transforms.intensity import (
+    RandScaleIntensity,
+    RandShiftIntensity,
+    ScaleIntensityRange,
+)
+from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Spacing
+from eva.vision.data.transforms.utility import EnsureChannelFirst

-__all__ = [
+__all__ = [
+    "ResizeAndCrop",
+    "CropForeground",
+    "RandCropByPosNegLabel",
+    "SpatialPad",
+    "RandScaleIntensity",
+    "RandShiftIntensity",
+    "ScaleIntensityRange",
+    "RandFlip",
+    "RandRotate90",
+    "Spacing",
+    "EnsureChannelFirst",
+]
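The new top-level exports are thin wrappers that adapt MONAI array transforms to the torchvision `transforms.v2` interface, so they can be chained with `v2.Compose`. The following is a minimal sketch, not taken from the package documentation; the parameter values are illustrative assumptions and only constructors shown in this diff are used.

from torchvision.transforms import v2

from eva.vision.data import transforms

# Hypothetical 3D preprocessing/augmentation pipeline built from the newly
# exported transforms; the values are placeholders, not recommended settings.
preprocessing = v2.Compose([
    transforms.ScaleIntensityRange(input_range=(-175.0, 250.0), output_range=(0.0, 1.0)),
    transforms.CropForeground(threshold=0.0),
    transforms.SpatialPad(spatial_size=(96, 96, 96)),
    transforms.RandFlip(prob=0.2, spatial_axes=(0, 1, 2)),
    transforms.RandScaleIntensity(factors=0.1, prob=0.5),
    transforms.RandShiftIntensity(offsets=0.1, prob=0.5),
])

Note that `CropForeground` looks up an eva `Volume` among the inputs (and `RandCropByPosNegLabel` a `Mask`) inside `_get_params`, so pipelines like this one are intended for volumetric tv_tensor inputs rather than plain tensors.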
eva/vision/data/transforms/common/__init__.py
CHANGED
@@ -1,6 +1,5 @@
 """Common vision transforms."""

-from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
 from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop

-__all__ = ["
+__all__ = ["ResizeAndCrop"]
eva/vision/data/transforms/croppad/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""Transforms for crop and pad operations."""
+
+from eva.vision.data.transforms.croppad.crop_foreground import CropForeground
+from eva.vision.data.transforms.croppad.rand_crop_by_pos_neg_label import RandCropByPosNegLabel
+from eva.vision.data.transforms.croppad.spatial_pad import SpatialPad
+
+__all__ = [
+    "CropForeground",
+    "RandCropByPosNegLabel",
+    "SpatialPad",
+]
eva/vision/data/transforms/croppad/crop_foreground.py
ADDED
@@ -0,0 +1,110 @@
+"""Crop foreground transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.config import type_definitions
+from monai.transforms.croppad import array as monai_croppad_transforms
+from monai.utils.enums import PytorchPadMode
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class CropForeground(v2.Transform):
+    """Crop an image using a bounding box.
+
+    The bounding box is generated by selecting foreground using select_fn
+    at channels channel_indices. margin is added in each spatial dimension
+    of the bounding box. The typical usage is to help training and evaluation
+    if the valid part is small in the whole medical image.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 0.0,
+        channel_indices: type_definitions.IndexSelection | None = None,
+        margin: Sequence[int] | int = 0,
+        allow_smaller: bool = True,
+        return_coords: bool = False,
+        k_divisible: Sequence[int] | int = 1,
+        mode: str = PytorchPadMode.CONSTANT,
+        **pad_kwargs,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            threshold: function to select expected foreground.
+            channel_indices: if defined, select foreground only on the specified channels
+                of image. if None, select foreground on the whole image.
+            margin: add margin value to spatial dims of the bounding box, if only 1 value provided,
+                use it for all dims.
+            allow_smaller: when computing box size with `margin`, whether to allow the image edges
+                to be smaller than the final box edges. If `False`, part of a padded output box
+                might be outside of the original image, if `True`, the image edges will be used as
+                the box edges. Default to `True`.
+            return_coords: whether return the coordinates of spatial bounding box for foreground.
+            k_divisible: make each spatial dimension to be divisible by k, default to 1.
+                if `k_divisible` is an int, the same `k` be applied to all the input spatial
+                dimensions.
+            mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``,
+                ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``,
+                ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor:
+                {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed
+                string values or a user supplied function. Defaults to ``"constant"``.
+                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+            pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
+                note that `np.pad` treats channel dimension as the first dimension.
+        """
+        super().__init__()
+
+        self._foreground_crop = monai_croppad_transforms.CropForeground(
+            select_fn=functools.partial(_threshold_fn, threshold=threshold),
+            channel_indices=channel_indices,
+            margin=margin,
+            allow_smaller=allow_smaller,
+            return_coords=return_coords,
+            k_divisible=k_divisible,
+            mode=mode,
+            lazy=False,
+            **pad_kwargs,
+        )
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
+        box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
+        return {"box_start": box_start, "box_end": box_end}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_foreground_cropped = self._foreground_crop.crop_pad(
+            inpt, params["box_start"], params["box_end"]
+        )
+        return tv_tensors.wrap(inpt_foreground_cropped, like=inpt)
+
+
+def _threshold_fn(image: torch.Tensor, threshold: int | float = 0.3) -> torch.Tensor:
+    """Applies a thresholding operation to an image tensor.
+
+    Pixels greater than the threshold are set to True, while others are False.
+
+    Args:
+        image: Input image tensor with pixel values.
+        threshold: Threshold value.
+
+    Returns:
+        A binary mask tensor of the same shape as `image`,
+        where True represents pixels above the threshold.
+    """
+    return image > threshold
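All of the new transform modules added in this release share the same wrapping pattern: a `torchvision.transforms.v2.Transform` whose `_transform` is a `functools.singledispatchmethod` that passes unregistered input types through unchanged and re-wraps the handled tv_tensor types with `tv_tensors.wrap`. A self-contained toy illustration of that dispatch mechanism (not part of the package):

import functools
from typing import Any, Dict

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2


class AddOne(v2.Transform):
    """Toy transform: increments Image inputs, passes everything else through."""

    @functools.singledispatchmethod
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt  # fallback for unregistered types

    @_transform.register(tv_tensors.Image)
    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return tv_tensors.wrap(inpt + 1, like=inpt)


image = tv_tensors.Image(torch.zeros(3, 4, 4))
mask = tv_tensors.Mask(torch.zeros(4, 4, dtype=torch.long))
out_image, out_mask = AddOne()(image, mask)  # only the Image is incremented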
eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py
ADDED
@@ -0,0 +1,109 @@
+"""Crop foreground transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.transforms.croppad import array as monai_croppad_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandCropByPosNegLabel(v2.Transform):
+    """Crop random fixed sized regions with the center being a foreground or background voxel.
+
+    Its based on the Pos Neg Ratio and will return a list of arrays for all the cropped images.
+    For example, crop two (3 x 3) arrays from (5 x 5) array with pos/neg=1::
+
+        [[[0, 0, 0, 0, 0],
+          [0, 1, 2, 1, 0],            [[0, 1, 2],     [[2, 1, 0],
+          [0, 1, 3, 0, 0],     -->     [0, 1, 3],      [3, 0, 0],
+          [0, 0, 0, 0, 0],             [0, 0, 0]]      [0, 0, 0]]
+          [0, 0, 0, 0, 0]]]
+
+    If a dimension of the expected spatial size is larger than the input image size,
+    will not crop that dimension. So the cropped result may be smaller than expected
+    size, and the cropped results of several images may not have exactly same shape.
+    """
+
+    def __init__(
+        self,
+        spatial_size: Sequence[int] | int,
+        label: torch.Tensor | None = None,
+        pos: float = 1.0,
+        neg: float = 1.0,
+        num_samples: int = 1,
+        image: torch.Tensor | None = None,
+        image_threshold: float = 0.0,
+        fg_indices: NdarrayOrTensor | None = None,
+        bg_indices: NdarrayOrTensor | None = None,
+        allow_smaller: bool = False,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
+                if a dimension of ROI size is larger than image size, will not crop that dimension.
+                if components have non-positive values, corresponding size of `label` will be used.
+                for example: if the spatial size of input data is [40, 40, 40] and
+                `spatial_size=[32, 64, -1]`, the spatial size of output data will be [32, 40, 40].
+            label: the label image that is used for finding foreground/background, if None, must
+                set at `self.__call__`. Non-zero indicates foreground, zero indicates background.
+            pos: used with `neg` together to calculate the ratio ``pos / (pos + neg)`` for
+                the probability to pick a foreground voxel as center rather than background voxel.
+            neg: used with `pos` together to calculate the ratio ``pos / (pos + neg)`` for
+                the probability to pick a foreground voxel as center rather than background voxel.
+            num_samples: number of samples (crop regions) to take in each list.
+            image: optional image data to help select valid area, can be same as `img` or another.
+                if not None, use ``label == 0 & image > image_threshold`` to select the negative
+                sample (background) center. Crop center will only come from valid image areas.
+            image_threshold: if enabled `image`, use ``image > image_threshold`` to determine
+                the valid image content areas.
+            fg_indices: if provided pre-computed foreground indices of `label`, will ignore `image`
+                and `image_threshold`, randomly select crop centers based on them, need to provide
+                `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                a typical usage is to call `FgBgToIndices` transform first and cache the results.
+            bg_indices: if provided pre-computed background indices of `label`, will ignore `image`
+                and `image_threshold`, randomly select crop centers based on them, need to provide
+                `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                a typical usage is to call `FgBgToIndices` transform first and cache the results.
+            allow_smaller: if `False`, an exception will be raised if the image is smaller than
+                the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
+                match the cropped size (i.e., no cropping in that dimension).
+        """
+        super().__init__()
+
+        self._rand_crop = monai_croppad_transforms.RandCropByPosNegLabel(
+            spatial_size=spatial_size,
+            label=label,
+            pos=pos,
+            neg=neg,
+            num_samples=num_samples,
+            image=image,
+            image_threshold=image_threshold,
+            fg_indices=fg_indices,
+            bg_indices=bg_indices,
+            allow_smaller=allow_smaller,
+            lazy=False,
+        )
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
+        self._rand_crop.randomize(label=mask)
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
+        return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]
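A hedged usage sketch (not from the package): `_get_params` calls `randomize` once on the `Mask` it finds among the inputs, so applying the transform to an image/mask pair yields spatially aligned lists of `num_samples` crops. The tensor layout below is an assumption for illustration.

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import RandCropByPosNegLabel

crop = RandCropByPosNegLabel(spatial_size=(32, 32, 32), pos=1.0, neg=1.0, num_samples=2)

image = tv_tensors.Image(torch.rand(1, 64, 64, 64))           # assumed (C, D, H, W) layout
mask = tv_tensors.Mask(torch.randint(0, 2, (1, 64, 64, 64)))  # non-zero voxels are foreground

image_crops, mask_crops = crop(image, mask)  # two aligned lists of 32^3 crops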
eva/vision/data/transforms/croppad/spatial_pad.py
ADDED
@@ -0,0 +1,67 @@
+"""General purpose cropper to produce sub-volume region of interest (ROI)."""
+
+import functools
+from typing import Any, Dict, Sequence
+
+from monai.transforms.croppad import array as monai_croppad_transforms
+from monai.utils.enums import Method, PytorchPadMode
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class SpatialPad(v2.Transform):
+    """Performs padding to the data.
+
+    Padding is applied symmetric for all sides or all on one side for each dimension.
+    """
+
+    def __init__(
+        self,
+        spatial_size: Sequence[int] | int | tuple[tuple[int, ...] | int, ...],
+        method: str = Method.SYMMETRIC,
+        mode: str = PytorchPadMode.CONSTANT,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            spatial_size: The spatial size of output data after padding.
+                If a dimension of the input data size is larger than the
+                pad size, will not pad that dimension. If its components
+                have non-positive values, the corresponding size of input
+                image will be used (no padding). for example: if the spatial
+                size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`,
+                the spatial size of output data will be [32, 30, 30].
+            method: {``"symmetric"``, ``"end"``}
+                Pad image symmetrically on every side or only pad at the
+                end sides. Defaults to ``"symmetric"``.
+            mode: available modes for numpy array:{``"constant"``, ``"edge"``,
+                ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``,
+                ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+                available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``,
+                ``"circular"``}. One of the listed string values or a user supplied function.
+                Defaults to ``"constant"``.
+                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+        """
+        super().__init__()
+
+        self._spatial_pad = monai_croppad_transforms.SpatialPad(
+            spatial_size=spatial_size,
+            method=method,
+            mode=mode,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_padded = self._spatial_pad(inpt)
+        return tv_tensors.wrap(inpt_padded, like=inpt)
eva/vision/data/transforms/intensity/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""Transforms for intensity adjustment."""
+
+from eva.vision.data.transforms.intensity.rand_scale_intensity import RandScaleIntensity
+from eva.vision.data.transforms.intensity.rand_shift_intensity import RandShiftIntensity
+from eva.vision.data.transforms.intensity.scale_intensity_ranged import ScaleIntensityRange
+
+__all__ = [
+    "RandScaleIntensity",
+    "RandShiftIntensity",
+    "ScaleIntensityRange",
+]
eva/vision/data/transforms/intensity/rand_scale_intensity.py
ADDED
@@ -0,0 +1,59 @@
+"""Intensity scaling transform."""
+
+import functools
+from typing import Any, Dict
+
+import numpy as np
+from monai.config.type_definitions import DtypeLike
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandScaleIntensity(v2.Transform):
+    """Randomly scale the intensity of input image.
+
+    The factor is by ``v = v * (1 + factor)``, where
+    the `factor` is randomly picked.
+    """
+
+    def __init__(
+        self,
+        factors: tuple[float, float] | float,
+        prob: float = 0.1,
+        channel_wise: bool = False,
+        dtype: DtypeLike = np.float32,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            factors: factor range to randomly scale by ``v = v * (1 + factor)``.
+                if single number, factor value is picked from (-factors, factors).
+            prob: probability of scale.
+            channel_wise: if True, shift intensity on each channel separately.
+                For each channel, a random offset will be chosen. Please ensure
+                that the first dimension represents the channel of the image if True.
+            dtype: output data type, if None, same as input image. defaults to float32.
+        """
+        super().__init__()
+
+        self._rand_scale_intensity = monai_intensity_transforms.RandScaleIntensity(
+            factors=factors,
+            prob=prob,
+            channel_wise=channel_wise,
+            dtype=dtype,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._rand_scale_intensity(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
eva/vision/data/transforms/intensity/rand_shift_intensity.py
ADDED
@@ -0,0 +1,55 @@
+"""Intensity shifting transform."""
+
+import functools
+from typing import Any, Dict
+
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandShiftIntensity(v2.Transform):
+    """Randomly shift intensity with randomly picked offset."""
+
+    def __init__(
+        self,
+        offsets: tuple[float, float] | float,
+        safe: bool = False,
+        prob: float = 0.1,
+        channel_wise: bool = False,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            offsets: Offset range to randomly shift.
+                if single number, offset value is picked from (-offsets, offsets).
+            safe: If `True`, then do safe dtype convert when intensity overflow.
+                E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then
+                `[256, -12]` -> `[array(255), array(0)]`.
+            prob: Probability of shift.
+            channel_wise: If True, shift intensity on each channel separately.
+                For each channel, a random offset will be chosen. Please ensure
+                that the first dimension represents the channel of the image if True.
+        """
+        super().__init__()
+
+        self._rand_swift_intensity = monai_intensity_transforms.RandShiftIntensity(
+            offsets=offsets,
+            safe=safe,
+            prob=prob,
+            channel_wise=channel_wise,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._rand_swift_intensity(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
eva/vision/data/transforms/intensity/scale_intensity_ranged.py
ADDED
@@ -0,0 +1,56 @@
+"""Intensity scaling transform."""
+
+import functools
+from typing import Any, Dict, Tuple
+
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class ScaleIntensityRange(v2.Transform):
+    """Intensity scaling transform.
+
+    Scaling from [a_min, a_max] to [b_min, b_max] with clip option.
+
+    When `b_min` or `b_max` are `None`, `scaled_array * (b_max - b_min) + b_min`
+    will be skipped. If `clip=True`, when `b_min`/`b_max` is None, the clipping
+    is not performed on the corresponding edge.
+    """
+
+    def __init__(
+        self,
+        input_range: Tuple[float, float],
+        output_range: Tuple[float, float] | None = None,
+        clip: bool = True,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            input_range: Intensity original range min and max.
+            output_range: Intensity target range min and max.
+            clip: Whether to perform clip after scaling.
+        """
+        super().__init__()
+
+        self._scale_intensity_range = monai_intensity_transforms.ScaleIntensityRange(
+            a_min=input_range[0],
+            a_max=input_range[1],
+            b_min=output_range[0] if output_range else None,
+            b_max=output_range[1] if output_range else None,
+            clip=clip,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._scale_intensity_range(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
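For reference, MONAI's `ScaleIntensityRange` applies `(x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min` and optionally clips the result to the output range. A short hedged example with an assumed CT-style window:

import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import ScaleIntensityRange

# Assumed window: values in [-175, 250] mapped to [0, 1], then clipped.
scale = ScaleIntensityRange(input_range=(-175.0, 250.0), output_range=(0.0, 1.0), clip=True)

image = tv_tensors.Image(torch.tensor([[-500.0, -175.0, 37.5, 250.0, 1000.0]]))
print(scale(image))  # approximately [0.0, 0.0, 0.5, 1.0, 1.0]; (37.5 + 175) / 425 = 0.5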
eva/vision/data/transforms/spatial/__init__.py
ADDED
@@ -0,0 +1,7 @@
+"""Transforms for spatial operations."""
+
+from eva.vision.data.transforms.spatial.flip import RandFlip
+from eva.vision.data.transforms.spatial.rotate import RandRotate90
+from eva.vision.data.transforms.spatial.spacing import Spacing
+
+__all__ = ["Spacing", "RandFlip", "RandRotate90"]
eva/vision/data/transforms/spatial/flip.py
ADDED
@@ -0,0 +1,72 @@
+"""Flip transforms."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.transforms.spatial import array as monai_spatial_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandFlip(v2.Transform):
+    """Randomly flips the image along axes."""
+
+    def __init__(
+        self,
+        prob: float = 0.1,
+        spatial_axes: Sequence[int] | int | None = None,
+        apply_per_axis: bool = True,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            prob: Probability of flipping.
+            spatial_axes: Spatial axes along which to flip over. Default is None.
+            apply_per_axis: If True, will apply a random flip transform to each
+                axis individually (if spatial_axes is a sequence of multiple axis).
+                If False, will apply a single random flip transform applied to all axes.
+        """
+        super().__init__()
+
+        if apply_per_axis:
+            if not isinstance(spatial_axes, (list, tuple)):
+                raise ValueError(
+                    "`spatial_axis` is expected to be sequence `apply_per_axis` "
+                    f"is enabled, got {type(spatial_axes)}"
+                )
+            self._flips = [
+                monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=axis)
+                for axis in spatial_axes
+            ]
+        else:
+            self._flips = [monai_spatial_transforms.RandFlip(prob=prob, spatial_axis=spatial_axes)]
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        for flip in self._flips:
+            flip.randomize(None)
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_flipped = self._apply_flips(inpt)
+        return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_flipped = torch.tensor(self._apply_flips(inpt), dtype=torch.long)
+        return tv_tensors.wrap(inpt_flipped, like=inpt)
+
+    def _apply_flips(self, inpt: Any) -> Any:
+        for flip in self._flips:
+            inpt = flip(img=inpt, randomize=False)
+        return inpt
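Note that with the default `apply_per_axis=True`, `spatial_axes` must be a sequence and an independent flip decision is drawn for each listed axis; with `apply_per_axis=False` a single decision flips all given axes together. A brief sketch (the axis order is an assumption):

from eva.vision.data.transforms import RandFlip

# Independent 50% flips along each of three spatial axes.
per_axis_flip = RandFlip(prob=0.5, spatial_axes=(0, 1, 2), apply_per_axis=True)

# One 50% decision that flips axes 0 and 1 together.
joint_flip = RandFlip(prob=0.5, spatial_axes=(0, 1), apply_per_axis=False)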