kaiko-eva 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva might be problematic.
- eva/core/data/dataloaders/__init__.py +2 -1
- eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/core/data/dataloaders/collate_fn/collate.py +24 -0
- eva/core/data/dataloaders/dataloader.py +4 -0
- eva/core/interface/interface.py +34 -1
- eva/core/metrics/defaults/classification/multiclass.py +45 -35
- eva/core/models/modules/__init__.py +2 -1
- eva/core/models/modules/scheduler.py +51 -0
- eva/core/models/transforms/extract_cls_features.py +1 -1
- eva/core/models/transforms/extract_patch_features.py +1 -1
- eva/core/models/wrappers/base.py +17 -14
- eva/core/models/wrappers/from_function.py +5 -4
- eva/core/models/wrappers/from_torchhub.py +5 -6
- eva/core/models/wrappers/huggingface.py +8 -5
- eva/core/models/wrappers/onnx.py +4 -4
- eva/core/trainers/functional.py +40 -43
- eva/core/utils/factory.py +66 -0
- eva/core/utils/registry.py +42 -0
- eva/core/utils/requirements.py +26 -0
- eva/language/__init__.py +13 -0
- eva/language/data/__init__.py +5 -0
- eva/language/data/datasets/__init__.py +9 -0
- eva/language/data/datasets/classification/__init__.py +7 -0
- eva/language/data/datasets/classification/base.py +63 -0
- eva/language/data/datasets/classification/pubmedqa.py +149 -0
- eva/language/data/datasets/language.py +13 -0
- eva/language/models/__init__.py +25 -0
- eva/language/models/modules/__init__.py +5 -0
- eva/language/models/modules/text.py +85 -0
- eva/language/models/modules/typings.py +16 -0
- eva/language/models/wrappers/__init__.py +11 -0
- eva/language/models/wrappers/huggingface.py +69 -0
- eva/language/models/wrappers/litellm.py +77 -0
- eva/language/models/wrappers/vllm.py +149 -0
- eva/language/utils/__init__.py +5 -0
- eva/language/utils/str_to_int_tensor.py +95 -0
- eva/vision/data/dataloaders/__init__.py +2 -1
- eva/vision/data/dataloaders/worker_init.py +35 -0
- eva/vision/data/datasets/__init__.py +5 -5
- eva/vision/data/datasets/segmentation/__init__.py +4 -4
- eva/vision/data/datasets/segmentation/btcv.py +3 -0
- eva/vision/data/datasets/segmentation/consep.py +5 -4
- eva/vision/data/datasets/segmentation/lits17.py +231 -0
- eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
- eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
- eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
- eva/vision/data/transforms/__init__.py +11 -2
- eva/vision/data/transforms/base/__init__.py +5 -0
- eva/vision/data/transforms/base/monai.py +27 -0
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/squeeze.py +24 -0
- eva/vision/data/transforms/croppad/__init__.py +4 -0
- eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
- eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
- eva/vision/models/modules/semantic_segmentation.py +18 -7
- eva/vision/models/networks/backbones/__init__.py +2 -3
- eva/vision/models/networks/backbones/_utils.py +1 -1
- eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
- eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
- eva/vision/models/networks/backbones/pathology/histai.py +3 -3
- eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
- eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
- eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
- eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
- eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
- eva/vision/models/networks/backbones/pathology/paige.py +3 -3
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
- eva/vision/models/networks/backbones/radiology/voco.py +5 -5
- eva/vision/models/networks/backbones/registry.py +2 -44
- eva/vision/models/networks/backbones/timm/backbones.py +2 -2
- eva/vision/models/networks/backbones/universal/__init__.py +8 -1
- eva/vision/models/networks/backbones/universal/vit.py +53 -3
- eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
- eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
- eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
- eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
- eva/vision/models/wrappers/from_registry.py +14 -9
- eva/vision/models/wrappers/from_timm.py +6 -5
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/METADATA +10 -2
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/RECORD +88 -57
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/segmentation/lits.py +0 -199
- eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
- /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/msd_task7_pancreas.py
@@ -0,0 +1,243 @@
+"""Dataset for Task 7 (pancreas tumor) from the Medical Segmentation Decathlon (MSD)."""
+
+import glob
+import os
+import re
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import huggingface_hub
+from torchvision import tv_tensors
+from torchvision.datasets import utils as data_utils
+from typing_extensions import override
+
+from eva.core.utils import requirements
+from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.datasets.segmentation import _utils
+from eva.vision.data.datasets.segmentation.metadata import _msd_task7_pancreas
+from eva.vision.data.datasets.vision import VisionDataset
+
+
+class MSDTask7Pancreas(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
+    """Task 7 (pancreas tumor) of the Medical Segmentation Decathlon (MSD).
+
+    The data set consists of 420 portal-venous phase CT scans of patients undergoing
+    resection of pancreatic masses. The corresponding target ROIs were the pancreatic
+    parenchyma and pancreatic mass (cyst or tumor). This data set was selected due to
+    label unbalance between large (background), medium (pancreas) and small (tumor)
+    structures. The data was acquired in the Memorial Sloan Kettering Cancer
+    Center, New York, US.
+
+    More info:
+    - Paper: https://www.nature.com/articles/s41467-022-30695-9
+    - Dataset source: https://huggingface.co/datasets/Luffy503/VoCo_Downstream
+    """
+
+    _train_ids = _msd_task7_pancreas.train_ids
+    """File indices of the training set."""
+
+    _val_ids = _msd_task7_pancreas.val_ids
+    """File indices of the validation set."""
+
+    def __init__(
+        self,
+        root: str,
+        split: Literal["train", "val"] | None = None,
+        download: bool = False,
+        transforms: Callable | None = None,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Path to the dataset root directory.
+            split: Dataset split to use ('train' or 'val').
+                If None, it uses the full dataset.
+            download: Whether to download the dataset.
+            transforms: A callable object for applying data transformations.
+                If None, no transformations are applied.
+        """
+        super().__init__()
+
+        self._root = root
+        self._split = split
+        self._download = download
+        self._transforms = transforms
+
+        self._samples: Dict[int, Tuple[str, str]]
+        self._indices: List[int]
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return [
+            "background",
+            "pancreas",
+            "cancer",
+        ]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {label: index for index, label in enumerate(self.classes)}
+
+    @override
+    def filename(self, index: int) -> str:
+        return os.path.relpath(self._samples[self._indices[index]][0], self._root)
+
+    @override
+    def prepare_data(self) -> None:
+        if self._download:
+            self._download_dataset()
+
+    @override
+    def configure(self) -> None:
+        self._samples = self._find_samples()
+        self._indices = self._make_indices()
+
+    @override
+    def validate(self) -> None:
+        requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
+        def _valid_sample(index: int) -> bool:
+            """Indicates if the sample files exist and are reachable."""
+            volume_file, segmentation_file = self._samples[self._indices[index]]
+            return os.path.isfile(volume_file) and os.path.isfile(segmentation_file)
+
+        if len(self._samples) < len(self._indices):
+            raise OSError(f"Dataset is missing {len(self._indices) - len(self._samples)} files.")
+
+        invalid_samples = [self._samples[i] for i in range(len(self)) if not _valid_sample(i)]
+        if invalid_samples:
+            raise OSError(
+                f"Dataset '{self.__class__.__qualname__}' contains missing or "
+                f"corrupted samples ({len(invalid_samples)} in total). "
+                f"Examples of missing folders: {str(invalid_samples[:10])[:-1]}, ...]. "
+            )
+
+    @override
+    def __getitem__(
+        self, index: int
+    ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask, dict[str, Any]]:
+        volume = self.load_data(index)
+        mask = self.load_target(index)
+        metadata = self.load_metadata(index) or {}
+        volume_tensor, mask_tensor = self._apply_transforms(volume, mask)
+        return volume_tensor, mask_tensor, metadata
+
+    @override
+    def __len__(self) -> int:
+        return len(self._indices)
+
+    @override
+    def load_data(self, index: int) -> eva_tv_tensors.Volume:
+        """Loads the CT volume for a given sample.
+
+        Args:
+            index: The index of the desired sample.
+
+        Returns:
+            Tensor representing the CT volume of shape `[T, C, H, W]`.
+        """
+        ct_scan_file, _ = self._samples[self._indices[index]]
+        return _utils.load_volume_tensor(ct_scan_file)
+
+    @override
+    def load_target(self, index: int) -> tv_tensors.Mask:
+        """Loads the segmentation mask for a given sample.
+
+        Args:
+            index: The index of the desired sample.
+
+        Returns:
+            Tensor representing the segmentation mask of shape `[T, C, H, W]`.
+        """
+        ct_scan_file, mask_file = self._samples[self._indices[index]]
+        return _utils.load_mask_tensor(mask_file, ct_scan_file)
+
+    def _apply_transforms(
+        self, ct_scan: eva_tv_tensors.Volume, mask: tv_tensors.Mask
+    ) -> Tuple[eva_tv_tensors.Volume, tv_tensors.Mask]:
+        """Applies transformations to the provided data.
+
+        Args:
+            ct_scan: The CT volume tensor.
+            mask: The segmentation mask tensor.
+
+        Returns:
+            A tuple containing the transformed CT and mask tensors.
+        """
+        return self._transforms(ct_scan, mask) if self._transforms else (ct_scan, mask)
+
+    def _find_samples(self) -> Dict[int, Tuple[str, str]]:
+        """Retrieves the file paths for the CT volumes and segmentation.
+
+        Returns:
+            The a dictionary mapping the file id to the volume and segmentation paths.
+        """
+
+        def filename_id_volume(filename: str) -> int:
+            matches = re.match(r".*(\d{3})_\d{4}.*", filename)
+            if matches is None:
+                raise ValueError(f"Filename '{filename}' is not valid.")
+            return int(matches.group(1))
+
+        def filename_id_segmentation(filename: str) -> int:
+            matches = re.match(r".*(\d{3}).*", filename)
+            if matches is None:
+                raise ValueError(f"Filename '{filename}' is not valid.")
+            return int(matches.group(1))
+
+        optional_subdir = os.path.join(self._root, "Dataset007_Pancreas")
+        search_dir = optional_subdir if os.path.isdir(optional_subdir) else self._root
+
+        volume_files_pattern = os.path.join(search_dir, "imagesTr", "*.nii.gz")
+        volume_filenames = glob.glob(volume_files_pattern)
+        volume_ids = {filename_id_volume(filename): filename for filename in volume_filenames}
+
+        segmentation_files_pattern = os.path.join(search_dir, "labelsTr", "*.nii.gz")
+        segmentation_filenames = glob.glob(segmentation_files_pattern)
+        segmentation_ids = {
+            filename_id_segmentation(filename): filename for filename in segmentation_filenames
+        }
+
+        return {
+            file_id: (volume_ids[file_id], segmentation_ids[file_id])
+            for file_id in volume_ids.keys()
+        }
+
+    def _make_indices(self) -> List[int]:
+        """Builds the dataset indices for the specified split."""
+        file_ids = []
+        match self._split:
+            case "train":
+                file_ids = self._train_ids
+            case "val":
+                file_ids = self._val_ids
+            case None:
+                file_ids = self._train_ids + self._val_ids
+            case _:
+                raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
+
+        return file_ids
+
+    def _download_dataset(self) -> None:
+        hf_token = os.getenv("HF_TOKEN")
+        if not hf_token:
+            raise ValueError("Huggingface token required, please set the HF_TOKEN env variable.")
+
+        huggingface_hub.snapshot_download(
+            "Luffy503/VoCo_Downstream",
+            repo_type="dataset",
+            token=hf_token,
+            local_dir=self._root,
+            ignore_patterns=[".git*"],
+            allow_patterns=["**/Dataset007_Pancreas.zip"],
+        )
+
+        zip_path = os.path.join(self._root, "MSD_Decathlon/Dataset007_Pancreas.zip")
+        if not os.path.exists(zip_path):
+            raise FileNotFoundError(
+                f"MSD_Decathlon/Dataset007_Pancreas.zip not found in {self._root}, "
+                "something with the download went wrong."
+            )
+
+        data_utils.extract_archive(zip_path, self._root, remove_finished=True)
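For reviewers, the following is a minimal usage sketch of the new dataset class. It is not part of the package: the root path is hypothetical, and it assumes the module path shown in the diff is importable and that HF_TOKEN is exported before enabling downloads.

# Usage sketch only (illustrative root path; not shipped with kaiko-eva).
from eva.vision.data.datasets.segmentation.msd_task7_pancreas import MSDTask7Pancreas

dataset = MSDTask7Pancreas(
    root="/data/msd_task7_pancreas",  # hypothetical local directory
    split="train",
    download=False,  # set to True only with the HF_TOKEN environment variable set
)
dataset.prepare_data()  # downloads and extracts the archive when download=True
dataset.configure()     # resolves volume/mask paths and split indices
dataset.validate()      # checks dependency versions and that all files exist

volume, mask, metadata = dataset[0]  # Volume [T, C, H, W], Mask, metadata dict
print(len(dataset), volume.shape, mask.shape)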
eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -18,7 +18,7 @@ from typing_extensions import override
 from eva.core.utils import io as core_io
 from eva.core.utils import multiprocessing
 from eva.vision.data.datasets import _validators, structs, vision
-from eva.vision.data.datasets.segmentation import _total_segmentator
+from eva.vision.data.datasets.segmentation.metadata import _total_segmentator
 from eva.vision.utils import io


eva/vision/data/transforms/__init__.py
@@ -1,7 +1,13 @@
 """Vision data transforms."""

-from eva.vision.data.transforms.common import ResizeAndCrop
-from eva.vision.data.transforms.croppad import
+from eva.vision.data.transforms.common import ResizeAndCrop, Squeeze
+from eva.vision.data.transforms.croppad import (
+    CropForeground,
+    RandCropByLabelClasses,
+    RandCropByPosNegLabel,
+    RandSpatialCrop,
+    SpatialPad,
+)
 from eva.vision.data.transforms.intensity import (
     RandScaleIntensity,
     RandShiftIntensity,
@@ -12,7 +18,9 @@ from eva.vision.data.transforms.utility import EnsureChannelFirst

 __all__ = [
     "ResizeAndCrop",
+    "Squeeze",
     "CropForeground",
+    "RandCropByLabelClasses",
     "RandCropByPosNegLabel",
     "SpatialPad",
     "RandScaleIntensity",
@@ -20,6 +28,7 @@ __all__ = [
     "ScaleIntensityRange",
     "RandFlip",
     "RandRotate90",
+    "RandSpatialCrop",
     "Spacing",
     "EnsureChannelFirst",
 ]
eva/vision/data/transforms/base/monai.py
@@ -0,0 +1,27 @@
+"""Base class for MONAI transform wrappers."""
+
+import abc
+
+from torchvision.transforms import v2
+
+
+class RandomMonaiTransform(v2.Transform, abc.ABC):
+    """Base class for MONAI transform wrappers."""
+
+    @abc.abstractmethod
+    def set_random_state(self, seed: int) -> None:
+        """Set the random state for the transform.
+
+        MONAI's random transforms use numpy.random for random number generation
+        which is seeded at the very beginning by lightning's seed_everything, but when
+        torch spins up the dataloader workers, it will only reseed torch's random states
+        and not numpy - so you basically end up with multiple dataloader workers that
+        have equally seeded random transforms, resulting in redundant transform outputs
+        and therefore reducing the diversity of the resulting training data.
+        To solve this, this method should be called in the dataloader's worker_init_fn
+        with a unique seed for each worker.
+
+        Args:
+            seed: The seed to set for the random state of the transform.
+        """
+        pass
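The docstring above explains why MONAI-backed transforms must be reseeded per dataloader worker. The release ships its own helper for this (eva/vision/data/dataloaders/worker_init.py, whose contents are not shown in this diff); the snippet below is only a hypothetical illustration of the pattern the docstring describes.

# Hypothetical sketch; not the packaged worker_init implementation.
from torch.utils.data import get_worker_info

from eva.vision.data.transforms import base


def reseed_monai_transforms(worker_id: int) -> None:
    """Gives each worker's MONAI-backed transforms a distinct random state."""
    info = get_worker_info()
    if info is None:  # single-process data loading, nothing to reseed
        return
    transforms = getattr(info.dataset, "_transforms", None)  # attribute name as in the datasets above
    candidates = getattr(transforms, "transforms", [transforms]) if transforms is not None else []
    for transform in candidates:
        if isinstance(transform, base.RandomMonaiTransform):
            transform.set_random_state(info.seed % 2**32)  # torch's per-worker seed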
eva/vision/data/transforms/common/squeeze.py
@@ -0,0 +1,24 @@
+"""Squeeze transform."""
+
+from typing import Any
+
+import torch
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+
+
+class Squeeze(v2.Transform):
+    """Squeezes the input tensor accross all or specified dimensions."""
+
+    def __init__(self, dim: int | list[int] | None = None):
+        """Initializes the transform.
+
+        Args:
+            dim: If specified, the input will be squeezed only in the specified dimensions.
+        """
+        super().__init__()
+        self._dim = dim
+
+    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
+        output = torch.squeeze(inpt) if self._dim is None else torch.squeeze(inpt, dim=self._dim)
+        return tv_tensors.wrap(output, like=inpt)
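For context, a minimal usage sketch of the new Squeeze transform (illustrative shapes; not part of the package), assuming the export from eva.vision.data.transforms shown earlier in this diff:

# Illustrative only: drop the leading singleton dimension of a segmentation mask.
import torch
from torchvision import tv_tensors

from eva.vision.data.transforms import Squeeze

mask = tv_tensors.Mask(torch.zeros(1, 1, 64, 64, dtype=torch.long))
squeezed = Squeeze(dim=0)(mask)  # only dimension 0 is squeezed
print(type(squeezed).__name__, squeezed.shape)  # Mask torch.Size([1, 64, 64])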
eva/vision/data/transforms/croppad/__init__.py
@@ -1,11 +1,15 @@
 """Transforms for crop and pad operations."""

 from eva.vision.data.transforms.croppad.crop_foreground import CropForeground
+from eva.vision.data.transforms.croppad.rand_crop_by_label_classes import RandCropByLabelClasses
 from eva.vision.data.transforms.croppad.rand_crop_by_pos_neg_label import RandCropByPosNegLabel
+from eva.vision.data.transforms.croppad.rand_spatial_crop import RandSpatialCrop
 from eva.vision.data.transforms.croppad.spatial_pad import SpatialPad

 __all__ = [
     "CropForeground",
+    "RandCropByLabelClasses",
     "RandCropByPosNegLabel",
+    "RandSpatialCrop",
     "SpatialPad",
 ]
eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py
@@ -0,0 +1,74 @@
+"""Crop by label classes transform."""
+
+import functools
+from typing import Any, Dict, List, Sequence
+
+import torch
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.transforms.croppad import array as monai_croppad_transforms
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base
+
+
+class RandCropByLabelClasses(base.RandomMonaiTransform):
+    """Crop random fixed sized regions with the center belonging to one of the classes.
+
+    Please refer to `monai.transforms.croppad.RandCropByLabelClasses` docs for more details.
+    """
+
+    def __init__(
+        self,
+        spatial_size: Sequence[int] | int,
+        ratios: list[float | int] | None = None,
+        label: torch.Tensor | None = None,
+        num_classes: int | None = None,
+        num_samples: int = 1,
+        image: torch.Tensor | None = None,
+        image_threshold: float = 0.0,
+        indices: list[NdarrayOrTensor] | None = None,
+        allow_smaller: bool = False,
+        warn: bool = True,
+        max_samples_per_class: int | None = None,
+        lazy: bool = False,
+    ) -> None:
+        """Initializes the transform."""
+        super().__init__()
+
+        self._rand_crop = monai_croppad_transforms.RandCropByLabelClasses(
+            spatial_size=spatial_size,
+            ratios=ratios,
+            label=label,
+            num_classes=num_classes,
+            num_samples=num_samples,
+            image=image,
+            image_threshold=image_threshold,
+            indices=indices,
+            allow_smaller=allow_smaller,
+            warn=warn,
+            max_samples_per_class=max_samples_per_class,
+            lazy=lazy,
+        )
+
+    @override
+    def set_random_state(self, seed: int) -> None:
+        self._rand_crop.set_random_state(seed)
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
+        self._rand_crop.randomize(label=mask)
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        inpt_foreground_crops = self._rand_crop(img=inpt, randomize=False)
+        return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_foreground_crops]
eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py
@@ -7,13 +7,13 @@ import torch
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.transforms.croppad import array as monai_croppad_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override

 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base


-class RandCropByPosNegLabel(v2.Transform):
+class RandCropByPosNegLabel(base.RandomMonaiTransform):
     """Crop random fixed sized regions with the center being a foreground or background voxel.

     Its based on the Pos Neg Ratio and will return a list of arrays for all the cropped images.
@@ -91,6 +91,10 @@ class RandCropByPosNegLabel(v2.Transform):
             lazy=False,
         )

+    @override
+    def set_random_state(self, seed: int) -> None:
+        self._rand_crop.set_random_state(seed)
+
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
         self._rand_crop.randomize(label=mask)
eva/vision/data/transforms/croppad/rand_spatial_crop.py
@@ -0,0 +1,89 @@
+"""Crop image with random size or specific size ROI."""
+
+import functools
+from typing import Any, Dict, List, Sequence, Tuple
+
+from monai.transforms.croppad import array as monai_croppad_transforms
+from torchvision import tv_tensors
+from torchvision.transforms import v2
+from torchvision.transforms.v2 import _utils as tv_utils
+from typing_extensions import override
+
+from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+class RandSpatialCrop(v2.Transform):
+    """Crop image with random size or specific size ROI.
+
+    It can crop at a random position as center or at the image center.
+    And allows to set the minimum and maximum size to limit the randomly
+    generated ROI.
+    """
+
+    def __init__(
+        self,
+        roi_size: Sequence[int] | int,
+        max_roi_size: Sequence[int] | int | None = None,
+        random_center: bool = True,
+        random_size: bool = False,
+    ) -> None:
+        """Initializes the transform.
+
+        Args:
+            roi_size: if `random_size` is True, it specifies the minimum crop
+                region. if `random_size` is False, it specifies the expected
+                ROI size to crop. e.g. [224, 224, 128]. if a dimension of ROI
+                size is larger than image size, will not crop that dimension of
+                the image. If its components have non-positive values, the
+                corresponding size of input image will be used. for example: if
+                the spatial size of input data is [40, 40, 40] and
+                `roi_size=[32, 64, -1]`, the spatial size of output data will be
+                [32, 40, 40].
+            max_roi_size: if `random_size` is True and `roi_size` specifies the
+                min crop region size, `max_roi_size` can specify the max crop
+                region size. if None, defaults to the input image size. if its
+                components have non-positive values, the corresponding size of
+                input image will be used.
+            random_center: crop at random position as center or the image center.
+            random_size: crop with random size or specific size ROI. if True, the
+                actual size is sampled from `randint(roi_size, max_roi_size + 1)`.
+        """
+        super().__init__()
+
+        self._rand_spatial_crop = monai_croppad_transforms.RandSpatialCrop(
+            roi_size=roi_size,
+            max_roi_size=max_roi_size,
+            random_center=random_center,
+            random_size=random_size,
+        )
+        self._cropper = monai_croppad_transforms.Crop()
+
+    def set_random_state(self, seed: int) -> None:
+        """Set the random state for the transform."""
+        self._rand_spatial_crop.set_random_state(seed)
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        t, h, w = tv_utils.query_chw(flat_inputs)
+        self._rand_spatial_crop.randomize((t, h, w))
+        return {}
+
+    @functools.singledispatchmethod
+    @override
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return inpt
+
+    @_transform.register(tv_tensors.Image)
+    @_transform.register(eva_tv_tensors.Volume)
+    @_transform.register(tv_tensors.Mask)
+    def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        slices = self._get_crop_slices()
+        inpt_rand_crop = self._cropper(inpt, slices=slices)
+        return tv_tensors.wrap(inpt_rand_crop, like=inpt)
+
+    def _get_crop_slices(self) -> Tuple[slice, ...]:
+        """Returns the sequence of slices to crop."""
+        if self._rand_spatial_crop.random_center:
+            return self._rand_spatial_crop._slices
+
+        central_cropper = monai_croppad_transforms.CenterSpatialCrop(self._size)
+        return central_cropper.compute_slices(self._rand_spatial_crop._size)  # type: ignore
eva/vision/data/transforms/intensity/rand_scale_intensity.py
@@ -7,13 +7,13 @@ import numpy as np
 from monai.config.type_definitions import DtypeLike
 from monai.transforms.intensity import array as monai_intensity_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override

 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base


-class RandScaleIntensity(v2.Transform):
+class RandScaleIntensity(base.RandomMonaiTransform):
     """Randomly scale the intensity of input image.

     The factor is by ``v = v * (1 + factor)``, where
@@ -47,6 +47,10 @@ class RandScaleIntensity(v2.Transform):
             dtype=dtype,
         )

+    @override
+    def set_random_state(self, seed: int) -> None:
+        self._rand_scale_intensity.set_random_state(seed)
+
     @functools.singledispatchmethod
     @override
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
eva/vision/data/transforms/intensity/rand_shift_intensity.py
@@ -5,13 +5,13 @@ from typing import Any, Dict

 from monai.transforms.intensity import array as monai_intensity_transforms
 from torchvision import tv_tensors
-from torchvision.transforms import v2
 from typing_extensions import override

 from eva.vision.data import tv_tensors as eva_tv_tensors
+from eva.vision.data.transforms import base


-class RandShiftIntensity(v2.Transform):
+class RandShiftIntensity(base.RandomMonaiTransform):
     """Randomly shift intensity with randomly picked offset."""

     def __init__(
@@ -36,13 +36,17 @@ class RandShiftIntensity(v2.Transform):
         """
         super().__init__()

-        self.
+        self._rand_shift_intensity = monai_intensity_transforms.RandShiftIntensity(
             offsets=offsets,
             safe=safe,
             prob=prob,
             channel_wise=channel_wise,
         )

+    @override
+    def set_random_state(self, seed: int) -> None:
+        self._rand_shift_intensity.set_random_state(seed)
+
     @functools.singledispatchmethod
     @override
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -51,5 +55,5 @@ class RandShiftIntensity(v2.Transform):
     @_transform.register(tv_tensors.Image)
     @_transform.register(eva_tv_tensors.Volume)
     def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
-        inpt_scaled = self.
+        inpt_scaled = self._rand_shift_intensity(inpt)
         return tv_tensors.wrap(inpt_scaled, like=inpt)