kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/data/datasets/classification/embeddings.py +2 -2
- eva/core/data/datasets/classification/multi_embeddings.py +2 -2
- eva/core/data/datasets/embeddings.py +4 -4
- eva/core/data/samplers/classification/balanced.py +19 -18
- eva/core/loggers/utils/wandb.py +33 -0
- eva/core/models/modules/head.py +5 -3
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/functional.py +8 -5
- eva/core/trainers/trainer.py +32 -17
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +10 -2
- eva/vision/data/datasets/classification/__init__.py +9 -0
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +111 -0
- eva/vision/data/datasets/classification/breakhis.py +209 -0
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +158 -0
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/embeddings.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/monusac.py
CHANGED

@@ -16,12 +16,11 @@ from torchvision.datasets import utils
 from typing_extensions import override
 
 from eva.core.utils.progress_bar import tqdm
-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.segmentation import base
+from eva.vision.data.datasets import _validators, structs, vision
 from eva.vision.utils import io
 
 
-class MoNuSAC(base.ImageSegmentation):
+class MoNuSAC(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
     """MoNuSAC2020: A Multi-organ Nuclei Segmentation and Classification Challenge.
 
     Webpage: https://monusac-2020.grand-challenge.org/

@@ -112,13 +111,13 @@ class MoNuSAC(base.ImageSegmentation):
         )
 
     @override
-    def …
+    def load_data(self, index: int) -> tv_tensors.Image:
         image_path = self._image_files[index]
         image_rgb_array = io.read_image(image_path)
         return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
 
     @override
-    def …
+    def load_target(self, index: int) -> tv_tensors.Mask:
         semantic_labels = (
             self._load_semantic_mask_file(index)
             if self._export_masks
eva/vision/data/datasets/segmentation/total_segmentator_2d.py
CHANGED

@@ -17,12 +17,12 @@ from typing_extensions import override
 
 from eva.core.utils import io as core_io
 from eva.core.utils import multiprocessing
-from eva.vision.data.datasets import _validators, structs
-from eva.vision.data.datasets.segmentation import _total_segmentator
+from eva.vision.data.datasets import _validators, structs, vision
+from eva.vision.data.datasets.segmentation import _total_segmentator
 from eva.vision.utils import io
 
 
-class TotalSegmentator2D(base.ImageSegmentation):
+class TotalSegmentator2D(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
     """TotalSegmentator 2D segmentation dataset."""
 
     _expected_dataset_lengths: Dict[str, int] = {
@@ -206,19 +206,20 @@ class TotalSegmentator2D(base.ImageSegmentation):
         return len(self._indices)
 
     @override
-    def …
+    def load_data(self, index: int) -> tv_tensors.Image:
         sample_index, slice_index = self._indices[index]
         image_path = self._get_image_path(sample_index)
-        …
+        image_nii = io.read_nifti(image_path, slice_index)
+        image_array = io.nifti_to_array(image_nii)
         image_array = self._fix_orientation(image_array)
         return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))
 
     @override
-    def …
+    def load_target(self, index: int) -> tv_tensors.Mask:
         if self._optimize_mask_loading:
             mask = self._load_semantic_label_mask(index)
         else:
-            mask = self.…
+            mask = self._load_target(index)
         mask = self._fix_orientation(mask)
         return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64)  # type: ignore
@@ -227,14 +228,15 @@ class TotalSegmentator2D(base.ImageSegmentation):
         _, slice_index = self._indices[index]
         return {"slice_index": slice_index}
 
-    def …
+    def _load_target(self, index: int) -> npt.NDArray[Any]:
         sample_index, slice_index = self._indices[index]
         return self._load_masks_as_semantic_label(sample_index, slice_index)
 
     def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
         """Loads the segmentation mask from a semantic label NifTi file."""
         sample_index, slice_index = self._indices[index]
-        …
+        nii = io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
+        return io.nifti_to_array(nii)
 
     def _load_masks_as_semantic_label(
         self, sample_index: int, slice_index: int | None = None

@@ -248,7 +250,7 @@ class TotalSegmentator2D(base.ImageSegmentation):
         masks_dir = self._get_masks_dir(sample_index)
         classes = self._class_mappings.keys() if self._class_mappings else self.classes[1:]
         mask_paths = [os.path.join(masks_dir, f"{label}.nii.gz") for label in classes]
-        binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]
+        binary_masks = [io.nifti_to_array(io.read_nifti(path, slice_index)) for path in mask_paths]
 
         if self._class_mappings:
             mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
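Note: the hunks above switch NIfTI loading to a two-step pattern, where `io.read_nifti` returns the NIfTI object and the new `io.nifti_to_array` extracts the voxel data. A minimal sketch of that pattern, with a placeholder path and slice index:

    from eva.vision.utils import io

    # Two-step loading as used in TotalSegmentator2D above; the path and slice
    # index are placeholder values for illustration only.
    nii = io.read_nifti("ct_scan.nii.gz", 42)   # read the NIfTI object for slice 42
    array = io.nifti_to_array(nii)              # extract the voxel data as an array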
eva/vision/data/datasets/vision.py
CHANGED

@@ -1,17 +1,92 @@
 """Vision Dataset base class."""
 
 import abc
-from typing import Generic, TypeVar
+from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar
 
 from eva.core.data.datasets import base
 
-…
-"""The data …
+InputType = TypeVar("InputType")
+"""The input data type."""
 
+TargetType = TypeVar("TargetType")
+"""The target data type."""
 
-class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
+
+class VisionDataset(
+    base.MapDataset[Tuple[InputType, TargetType, Dict[str, Any]]],
+    abc.ABC,
+    Generic[InputType, TargetType],
+):
     """Base dataset class for vision tasks."""
 
+    def __init__(
+        self,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
+        """
+        super().__init__()
+
+        self._transforms = transforms
+
+    @property
+    def classes(self) -> List[str] | None:
+        """Returns the list with names of the dataset names."""
+
+    @property
+    def class_to_idx(self) -> Dict[str, int] | None:
+        """Returns a mapping of the class name to its target index."""
+
+    def __getitem__(self, index: int) -> Tuple[InputType, TargetType, Dict[str, Any]]:
+        """Returns the `index`'th data sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            A tuple with the image, the target and the metadata.
+        """
+        image = self.load_data(index)
+        target = self.load_target(index)
+        image, target = self._apply_transforms(image, target)
+        return image, target, self.load_metadata(index) or {}
+
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        """Returns the dataset metadata.
+
+        Args:
+            index: The index of the data sample to return the metadata of.
+
+        Returns:
+            The sample metadata.
+        """
+
+    @abc.abstractmethod
+    def load_data(self, index: int) -> InputType:
+        """Returns the `index`'th data sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample data.
+        """
+
+    @abc.abstractmethod
+    def load_target(self, index: int) -> TargetType:
+        """Returns the `index`'th target sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample target.
+        """
+
     @abc.abstractmethod
     def filename(self, index: int) -> str:
         """Returns the filename of the `index`'th data sample.
@@ -24,3 +99,19 @@ class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
         Returns:
             The filename of the `index`'th data sample.
         """
+
+    def _apply_transforms(
+        self, image: InputType, target: TargetType
+    ) -> Tuple[InputType, TargetType]:
+        """Applies the transforms to the provided data and returns them.
+
+        Args:
+            image: The desired image.
+            target: The target of the image.
+
+        Returns:
+            A tuple with the image and the target transformed.
+        """
+        if self._transforms is not None:
+            image, target = self._transforms(image, target)
+        return image, target
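Note: the reworked base class replaces the task-specific dataset bases with one generic `VisionDataset`; subclasses implement `load_data`, `load_target` and `filename`, and `__getitem__` applies the transforms and returns an `(input, target, metadata)` triple. A minimal illustrative subclass, assuming in-memory dummy samples (this class is not part of the package):

    import torch
    from torchvision import tv_tensors
    from typing_extensions import override

    from eva.vision.data.datasets import vision


    class ToyDataset(vision.VisionDataset[tv_tensors.Image, tv_tensors.Mask]):
        """Illustrative subclass of the new VisionDataset base (not from the package)."""

        @override
        def load_data(self, index: int) -> tv_tensors.Image:
            return tv_tensors.Image(torch.rand(3, 224, 224))

        @override
        def load_target(self, index: int) -> tv_tensors.Mask:
            return tv_tensors.Mask(torch.zeros(224, 224, dtype=torch.int64))

        @override
        def filename(self, index: int) -> str:
            return f"sample_{index}.png"

        def __len__(self) -> int:
            return 10


    dataset = ToyDataset()
    image, mask, metadata = dataset[0]  # metadata falls back to {} when load_metadata returns None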
eva/vision/data/datasets/wsi.py
CHANGED

@@ -11,12 +11,12 @@ from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
+from eva.core.data.datasets import base
 from eva.vision.data import wsi
-from eva.vision.data.datasets import vision
 from eva.vision.data.wsi.patching import samplers
 
 
-class WsiDataset(vision.VisionDataset):
+class WsiDataset(base.MapDataset):
     """Dataset class for reading patches from whole-slide images."""
 
     def __init__(

@@ -57,8 +57,8 @@ class WsiDataset(vision.VisionDataset):
     def __len__(self):
         return len(self._coords.x_y)
 
-    @override
     def filename(self, index: int) -> str:
+        """Returns the filename of the patch at the specified index."""
        return f"{self._file_path}_{index}"
 
     @property

@@ -103,7 +103,7 @@ class WsiDataset(vision.VisionDataset):
         return image
 
 
-class MultiWsiDataset(vision.VisionDataset):
+class MultiWsiDataset(base.MapDataset):
     """Dataset class for reading patches from multiple whole-slide images."""
 
     def __init__(

@@ -171,8 +171,8 @@ class MultiWsiDataset(vision.VisionDataset):
     def __getitem__(self, index: int) -> tv_tensors.Image:
         return self._concat_dataset[index]
 
-    @override
     def filename(self, index: int) -> str:
+        """Returns the filename of the patch at the specified index."""
         return os.path.basename(self._file_paths[self._get_dataset_idx(index)])
 
     def load_metadata(self, index: int) -> Dict[str, Any]:
eva/vision/data/transforms/__init__.py
CHANGED

@@ -1,6 +1,25 @@
 """Vision data transforms."""
 
-from eva.vision.data.transforms.common import …
-from eva.vision.data.transforms.…
+from eva.vision.data.transforms.common import ResizeAndCrop
+from eva.vision.data.transforms.croppad import CropForeground, RandCropByPosNegLabel, SpatialPad
+from eva.vision.data.transforms.intensity import (
+    RandScaleIntensity,
+    RandShiftIntensity,
+    ScaleIntensityRange,
+)
+from eva.vision.data.transforms.spatial import RandFlip, RandRotate90, Spacing
+from eva.vision.data.transforms.utility import EnsureChannelFirst
 
-__all__ = […
+__all__ = [
+    "ResizeAndCrop",
+    "CropForeground",
+    "RandCropByPosNegLabel",
+    "SpatialPad",
+    "RandScaleIntensity",
+    "RandShiftIntensity",
+    "ScaleIntensityRange",
+    "RandFlip",
+    "RandRotate90",
+    "Spacing",
+    "EnsureChannelFirst",
+]
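Note: with the re-exports above, the new crop/pad and intensity transforms can be composed into a torchvision v2 pipeline. A sketch, assuming the pipeline is driven by a `(Volume, Mask)` pair and that `Volume` wraps a `(C, D, H, W)` tensor; only constructor arguments that appear in this diff are used, and the sizes and probabilities are arbitrary examples:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2

    from eva.vision.data import transforms
    from eva.vision.data import tv_tensors as eva_tv_tensors

    # Illustrative volumetric train-time pipeline composed from the new re-exports.
    train_transforms = v2.Compose([
        transforms.CropForeground(margin=2),
        transforms.SpatialPad(spatial_size=[96, 96, 96]),
        transforms.RandCropByPosNegLabel(spatial_size=[96, 96, 96], num_samples=2),
        transforms.RandScaleIntensity(factors=0.1, prob=0.5),
    ])

    volume = eva_tv_tensors.Volume(torch.rand(1, 64, 128, 128))  # assumes Volume wraps a tensor
    mask = tv_tensors.Mask((torch.rand(1, 64, 128, 128) > 0.95).long())
    volume_crops, mask_crops = train_transforms(volume, mask)    # lists of patches after the random crop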
eva/vision/data/transforms/common/__init__.py
CHANGED

@@ -1,6 +1,5 @@
 """Common vision transforms."""
 
-from eva.vision.data.transforms.common.resize_and_clamp import ResizeAndClamp
 from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop
 
-__all__ = ["…
+__all__ = ["ResizeAndCrop"]
eva/vision/data/transforms/croppad/__init__.py

@@ -0,0 +1,11 @@
+"""Transforms for crop and pad operations."""
+
+from eva.vision.data.transforms.croppad.crop_foreground import CropForeground
+from eva.vision.data.transforms.croppad.rand_crop_by_pos_neg_label import RandCropByPosNegLabel
+from eva.vision.data.transforms.croppad.spatial_pad import SpatialPad
+
+__all__ = [
+    "CropForeground",
+    "RandCropByPosNegLabel",
+    "SpatialPad",
+]
eva/vision/data/transforms/croppad/crop_foreground.py

@@ -0,0 +1,110 @@
+"""Crop foreground transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.config import type_definitions
+from monai.transforms.croppad import array as monai_croppad_transforms
+from monai.utils.enums import PytorchPadMode
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class CropForeground(v2.Transform):
+    """Crop an image using a bounding box.
+
+    The bounding box is generated by selecting foreground using select_fn
+    at channels channel_indices. margin is added in each spatial dimension
+    of the bounding box. The typical usage is to help training and evaluation
+    if the valid part is small in the whole medical image.
+    """
+
+    def __init__(
+        self,
+        threshold: float = 0.0,
+        channel_indices: type_definitions.IndexSelection | None = None,
+        margin: Sequence[int] | int = 0,
+        allow_smaller: bool = True,
+        return_coords: bool = False,
+        k_divisible: Sequence[int] | int = 1,
+        mode: str = PytorchPadMode.CONSTANT,
+        **pad_kwargs,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            threshold: function to select expected foreground.
+            channel_indices: if defined, select foreground only on the specified channels
+                of image. if None, select foreground on the whole image.
+            margin: add margin value to spatial dims of the bounding box, if only 1 value provided,
+                use it for all dims.
+            allow_smaller: when computing box size with `margin`, whether to allow the image edges
+                to be smaller than the final box edges. If `False`, part of a padded output box
+                might be outside of the original image, if `True`, the image edges will be used as
+                the box edges. Default to `True`.
+            return_coords: whether return the coordinates of spatial bounding box for foreground.
+            k_divisible: make each spatial dimension to be divisible by k, default to 1.
+                if `k_divisible` is an int, the same `k` be applied to all the input spatial
+                dimensions.
+            mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``,
+                ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``,
+                ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor:
+                {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed
+                string values or a user supplied function. Defaults to ``"constant"``.
+                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+            pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
+                note that `np.pad` treats channel dimension as the first dimension.
+        """
+        super().__init__()
+
+        self._foreground_crop = monai_croppad_transforms.CropForeground(
+            select_fn=functools.partial(_threshold_fn, threshold=threshold),
+            channel_indices=channel_indices,
+            margin=margin,
+            allow_smaller=allow_smaller,
+            return_coords=return_coords,
+            k_divisible=k_divisible,
+            mode=mode,
+            lazy=False,
+            **pad_kwargs,
+        )
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        volume = next(inpt for inpt in flat_inputs if isinstance(inpt, eva_tv_tensors.Volume))
+        box_start, box_end = self._foreground_crop.compute_bounding_box(volume)
+        return {"box_start": box_start, "box_end": box_end}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_foreground_cropped = self._foreground_crop.crop_pad(
+            inpt, params["box_start"], params["box_end"]
+        )
+        return tv_tensors.wrap(inpt_foreground_cropped, like=inpt)
+
+
+def _threshold_fn(image: torch.Tensor, threshold: int | float = 0.3) -> torch.Tensor:
+    """Applies a thresholding operation to an image tensor.
+
+    Pixels greater than the threshold are set to True, while others are False.
+
+    Args:
+        image: Input image tensor with pixel values.
+        threshold: Threshold value.
+
+    Returns:
+        A binary mask tensor of the same shape as `image`,
+        where True represents pixels above the threshold.
+    """
+    return image > threshold
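Note: a short usage sketch of the wrapper above, assuming `Volume` is constructed directly from a `(C, D, H, W)` tensor; the data is synthetic and the margin is arbitrary:

    import torch
    from torchvision import tv_tensors

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import CropForeground

    # Synthetic scan with a small non-zero foreground block.
    scan = torch.zeros(1, 32, 64, 64)
    scan[:, 10:20, 20:40, 20:40] = 1.0
    volume = eva_tv_tensors.Volume(scan)
    mask = tv_tensors.Mask(torch.zeros(1, 32, 64, 64, dtype=torch.int64))

    transform = CropForeground(threshold=0.0, margin=1)
    cropped_volume, cropped_mask = transform(volume, mask)  # both cropped to the same bounding box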
eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py

@@ -0,0 +1,109 @@
+"""Crop foreground transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.transforms.croppad import array as monai_croppad_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandCropByPosNegLabel(v2.Transform):
+    """Crop random fixed sized regions with the center being a foreground or background voxel.
+
+    Its based on the Pos Neg Ratio and will return a list of arrays for all the cropped images.
+    For example, crop two (3 x 3) arrays from (5 x 5) array with pos/neg=1::
+
+        [[[0, 0, 0, 0, 0],
+          [0, 1, 2, 1, 0],            [[0, 1, 2],     [[2, 1, 0],
+          [0, 1, 3, 0, 0],     -->     [0, 1, 3],      [3, 0, 0],
+          [0, 0, 0, 0, 0],             [0, 0, 0]]      [0, 0, 0]]
+          [0, 0, 0, 0, 0]]]
+
+    If a dimension of the expected spatial size is larger than the input image size,
+    will not crop that dimension. So the cropped result may be smaller than expected
+    size, and the cropped results of several images may not have exactly same shape.
+    """
+
+    def __init__(
+        self,
+        spatial_size: Sequence[int] | int,
+        label: torch.Tensor | None = None,
+        pos: float = 1.0,
+        neg: float = 1.0,
+        num_samples: int = 1,
+        image: torch.Tensor | None = None,
+        image_threshold: float = 0.0,
+        fg_indices: NdarrayOrTensor | None = None,
+        bg_indices: NdarrayOrTensor | None = None,
+        allow_smaller: bool = False,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
+                if a dimension of ROI size is larger than image size, will not crop that dimension.
+                if components have non-positive values, corresponding size of `label` will be used.
+                for example: if the spatial size of input data is [40, 40, 40] and
+                `spatial_size=[32, 64, -1]`, the spatial size of output data will be [32, 40, 40].
+            label: the label image that is used for finding foreground/background, if None, must
+                set at `self.__call__`. Non-zero indicates foreground, zero indicates background.
+            pos: used with `neg` together to calculate the ratio ``pos / (pos + neg)`` for
+                the probability to pick a foreground voxel as center rather than background voxel.
+            neg: used with `pos` together to calculate the ratio ``pos / (pos + neg)`` for
+                the probability to pick a foreground voxel as center rather than background voxel.
+            num_samples: number of samples (crop regions) to take in each list.
+            image: optional image data to help select valid area, can be same as `img` or another.
+                if not None, use ``label == 0 & image > image_threshold`` to select the negative
+                sample (background) center. Crop center will only come from valid image areas.
+            image_threshold: if enabled `image`, use ``image > image_threshold`` to determine
+                the valid image content areas.
+            fg_indices: if provided pre-computed foreground indices of `label`, will ignore `image`
+                and `image_threshold`, randomly select crop centers based on them, need to provide
+                `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                a typical usage is to call `FgBgToIndices` transform first and cache the results.
+            bg_indices: if provided pre-computed background indices of `label`, will ignore `image`
+                and `image_threshold`, randomly select crop centers based on them, need to provide
+                `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices.
+                a typical usage is to call `FgBgToIndices` transform first and cache the results.
+            allow_smaller: if `False`, an exception will be raised if the image is smaller than
+                the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
+                match the cropped size (i.e., no cropping in that dimension).
+        """
+        super().__init__()
+
+        self._rand_crop = monai_croppad_transforms.RandCropByPosNegLabel(
+            spatial_size=spatial_size,
+            label=label,
+            pos=pos,
+            neg=neg,
+            num_samples=num_samples,
+            image=image,
+            image_threshold=image_threshold,
+            fg_indices=fg_indices,
+            bg_indices=bg_indices,
+            allow_smaller=allow_smaller,
+            lazy=False,
+        )
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
+        self._rand_crop.randomize(label=mask)
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
+        return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]
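Note: a usage sketch of the wrapper above, assuming a `(Volume, Mask)` input pair with `Volume` wrapping a `(C, D, H, W)` tensor; the transform yields `num_samples` patches per input:

    import torch
    from torchvision import tv_tensors

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import RandCropByPosNegLabel

    volume = eva_tv_tensors.Volume(torch.rand(1, 64, 96, 96))          # assumed (C, D, H, W) layout
    mask = tv_tensors.Mask((torch.rand(1, 64, 96, 96) > 0.95).long())  # sparse foreground labels

    transform = RandCropByPosNegLabel(spatial_size=[32, 32, 32], pos=2.0, neg=1.0, num_samples=4)
    volume_patches, mask_patches = transform(volume, mask)  # two lists with 4 patches each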
eva/vision/data/transforms/croppad/spatial_pad.py

@@ -0,0 +1,67 @@
+"""General purpose cropper to produce sub-volume region of interest (ROI)."""
+
+import functools
+from typing import Any, Dict, Sequence
+
+from monai.transforms.croppad import array as monai_croppad_transforms
+from monai.utils.enums import Method, PytorchPadMode
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class SpatialPad(v2.Transform):
+    """Performs padding to the data.
+
+    Padding is applied symmetric for all sides or all on one side for each dimension.
+    """
+
+    def __init__(
+        self,
+        spatial_size: Sequence[int] | int | tuple[tuple[int, ...] | int, ...],
+        method: str = Method.SYMMETRIC,
+        mode: str = PytorchPadMode.CONSTANT,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            spatial_size: The spatial size of output data after padding.
+                If a dimension of the input data size is larger than the
+                pad size, will not pad that dimension. If its components
+                have non-positive values, the corresponding size of input
+                image will be used (no padding). for example: if the spatial
+                size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`,
+                the spatial size of output data will be [32, 30, 30].
+            method: {``"symmetric"``, ``"end"``}
+                Pad image symmetrically on every side or only pad at the
+                end sides. Defaults to ``"symmetric"``.
+            mode: available modes for numpy array:{``"constant"``, ``"edge"``,
+                ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``,
+                ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+                available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``,
+                ``"circular"``}. One of the listed string values or a user supplied function.
+                Defaults to ``"constant"``.
+                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+        """
+        super().__init__()
+
+        self._spatial_pad = monai_croppad_transforms.SpatialPad(
+            spatial_size=spatial_size,
+            method=method,
+            mode=mode,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_padded = self._spatial_pad(inpt)
+        return tv_tensors.wrap(inpt_padded, like=inpt)
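Note: a usage sketch, assuming a `Volume` constructed from a `(C, D, H, W)` tensor; dimensions already at or above the target size are left untouched:

    import torch

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import SpatialPad

    volume = eva_tv_tensors.Volume(torch.rand(1, 28, 100, 100))
    pad = SpatialPad(spatial_size=[32, 96, 96])
    padded = pad(volume)  # spatial size becomes (32, 100, 100); H/W already exceed the target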
eva/vision/data/transforms/intensity/__init__.py

@@ -0,0 +1,11 @@
+"""Transforms for intensity adjustment."""
+
+from eva.vision.data.transforms.intensity.rand_scale_intensity import RandScaleIntensity
+from eva.vision.data.transforms.intensity.rand_shift_intensity import RandShiftIntensity
+from eva.vision.data.transforms.intensity.scale_intensity_ranged import ScaleIntensityRange
+
+__all__ = [
+    "RandScaleIntensity",
+    "RandShiftIntensity",
+    "ScaleIntensityRange",
+]
eva/vision/data/transforms/intensity/rand_scale_intensity.py

@@ -0,0 +1,59 @@
+"""Intensity scaling transform."""
+
+import functools
+from typing import Any, Dict
+
+import numpy as np
+from monai.config.type_definitions import DtypeLike
+from monai.transforms.intensity import array as monai_intensity_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandScaleIntensity(v2.Transform):
+    """Randomly scale the intensity of input image.
+
+    The factor is by ``v = v * (1 + factor)``, where
+    the `factor` is randomly picked.
+    """
+
+    def __init__(
+        self,
+        factors: tuple[float, float] | float,
+        prob: float = 0.1,
+        channel_wise: bool = False,
+        dtype: DtypeLike = np.float32,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            factors: factor range to randomly scale by ``v = v * (1 + factor)``.
+                if single number, factor value is picked from (-factors, factors).
+            prob: probability of scale.
+            channel_wise: if True, shift intensity on each channel separately.
+                For each channel, a random offset will be chosen. Please ensure
+                that the first dimension represents the channel of the image if True.
+            dtype: output data type, if None, same as input image. defaults to float32.
+        """
+        super().__init__()
+
+        self._rand_scale_intensity = monai_intensity_transforms.RandScaleIntensity(
+            factors=factors,
+            prob=prob,
+            channel_wise=channel_wise,
+            dtype=dtype,
+        )
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
+        inpt_scaled = self._rand_scale_intensity(inpt)
+        return tv_tensors.wrap(inpt_scaled, like=inpt)
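Note: a usage sketch of the intensity wrapper above on a synthetic volume, assuming `Volume` wraps a tensor; the factor range and probability are illustrative:

    import torch

    from eva.vision.data import tv_tensors as eva_tv_tensors
    from eva.vision.data.transforms import RandScaleIntensity

    volume = eva_tv_tensors.Volume(torch.rand(1, 16, 64, 64))
    scale = RandScaleIntensity(factors=0.1, prob=1.0)  # v = v * (1 + factor), factor drawn from (-0.1, 0.1)
    scaled = scale(volume)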