kaiko-eva 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaiko-eva might be problematic.

Files changed (90)
  1. eva/core/data/dataloaders/__init__.py +2 -1
  2. eva/core/data/dataloaders/collate_fn/__init__.py +5 -0
  3. eva/core/data/dataloaders/collate_fn/collate.py +24 -0
  4. eva/core/data/dataloaders/dataloader.py +4 -0
  5. eva/core/interface/interface.py +34 -1
  6. eva/core/metrics/defaults/classification/multiclass.py +45 -35
  7. eva/core/models/modules/__init__.py +2 -1
  8. eva/core/models/modules/scheduler.py +51 -0
  9. eva/core/models/transforms/extract_cls_features.py +1 -1
  10. eva/core/models/transforms/extract_patch_features.py +1 -1
  11. eva/core/models/wrappers/base.py +17 -14
  12. eva/core/models/wrappers/from_function.py +5 -4
  13. eva/core/models/wrappers/from_torchhub.py +5 -6
  14. eva/core/models/wrappers/huggingface.py +8 -5
  15. eva/core/models/wrappers/onnx.py +4 -4
  16. eva/core/trainers/functional.py +40 -43
  17. eva/core/utils/factory.py +66 -0
  18. eva/core/utils/registry.py +42 -0
  19. eva/core/utils/requirements.py +26 -0
  20. eva/language/__init__.py +13 -0
  21. eva/language/data/__init__.py +5 -0
  22. eva/language/data/datasets/__init__.py +9 -0
  23. eva/language/data/datasets/classification/__init__.py +7 -0
  24. eva/language/data/datasets/classification/base.py +63 -0
  25. eva/language/data/datasets/classification/pubmedqa.py +149 -0
  26. eva/language/data/datasets/language.py +13 -0
  27. eva/language/models/__init__.py +25 -0
  28. eva/language/models/modules/__init__.py +5 -0
  29. eva/language/models/modules/text.py +85 -0
  30. eva/language/models/modules/typings.py +16 -0
  31. eva/language/models/wrappers/__init__.py +11 -0
  32. eva/language/models/wrappers/huggingface.py +69 -0
  33. eva/language/models/wrappers/litellm.py +77 -0
  34. eva/language/models/wrappers/vllm.py +149 -0
  35. eva/language/utils/__init__.py +5 -0
  36. eva/language/utils/str_to_int_tensor.py +95 -0
  37. eva/vision/data/dataloaders/__init__.py +2 -1
  38. eva/vision/data/dataloaders/worker_init.py +35 -0
  39. eva/vision/data/datasets/__init__.py +5 -5
  40. eva/vision/data/datasets/segmentation/__init__.py +4 -4
  41. eva/vision/data/datasets/segmentation/btcv.py +3 -0
  42. eva/vision/data/datasets/segmentation/consep.py +5 -4
  43. eva/vision/data/datasets/segmentation/lits17.py +231 -0
  44. eva/vision/data/datasets/segmentation/metadata/__init__.py +1 -0
  45. eva/vision/data/datasets/segmentation/metadata/_msd_task7_pancreas.py +287 -0
  46. eva/vision/data/datasets/segmentation/msd_task7_pancreas.py +243 -0
  47. eva/vision/data/datasets/segmentation/total_segmentator_2d.py +1 -1
  48. eva/vision/data/transforms/__init__.py +11 -2
  49. eva/vision/data/transforms/base/__init__.py +5 -0
  50. eva/vision/data/transforms/base/monai.py +27 -0
  51. eva/vision/data/transforms/common/__init__.py +2 -1
  52. eva/vision/data/transforms/common/squeeze.py +24 -0
  53. eva/vision/data/transforms/croppad/__init__.py +4 -0
  54. eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py +74 -0
  55. eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +6 -2
  56. eva/vision/data/transforms/croppad/rand_spatial_crop.py +89 -0
  57. eva/vision/data/transforms/intensity/rand_scale_intensity.py +6 -2
  58. eva/vision/data/transforms/intensity/rand_shift_intensity.py +8 -4
  59. eva/vision/models/modules/semantic_segmentation.py +18 -7
  60. eva/vision/models/networks/backbones/__init__.py +2 -3
  61. eva/vision/models/networks/backbones/_utils.py +1 -1
  62. eva/vision/models/networks/backbones/pathology/bioptimus.py +4 -4
  63. eva/vision/models/networks/backbones/pathology/gigapath.py +2 -2
  64. eva/vision/models/networks/backbones/pathology/histai.py +3 -3
  65. eva/vision/models/networks/backbones/pathology/hkust.py +2 -2
  66. eva/vision/models/networks/backbones/pathology/kaiko.py +7 -7
  67. eva/vision/models/networks/backbones/pathology/lunit.py +3 -3
  68. eva/vision/models/networks/backbones/pathology/mahmood.py +3 -3
  69. eva/vision/models/networks/backbones/pathology/owkin.py +3 -3
  70. eva/vision/models/networks/backbones/pathology/paige.py +3 -3
  71. eva/vision/models/networks/backbones/radiology/swin_unetr.py +2 -2
  72. eva/vision/models/networks/backbones/radiology/voco.py +5 -5
  73. eva/vision/models/networks/backbones/registry.py +2 -44
  74. eva/vision/models/networks/backbones/timm/backbones.py +2 -2
  75. eva/vision/models/networks/backbones/universal/__init__.py +8 -1
  76. eva/vision/models/networks/backbones/universal/vit.py +53 -3
  77. eva/vision/models/networks/decoders/segmentation/decoder2d.py +1 -1
  78. eva/vision/models/networks/decoders/segmentation/linear.py +1 -1
  79. eva/vision/models/networks/decoders/segmentation/semantic/common.py +2 -2
  80. eva/vision/models/networks/decoders/segmentation/typings.py +1 -1
  81. eva/vision/models/wrappers/from_registry.py +14 -9
  82. eva/vision/models/wrappers/from_timm.py +6 -5
  83. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/METADATA +10 -2
  84. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/RECORD +88 -57
  85. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/WHEEL +1 -1
  86. eva/vision/data/datasets/segmentation/lits.py +0 -199
  87. eva/vision/data/datasets/segmentation/lits_balanced.py +0 -94
  88. /eva/vision/data/datasets/segmentation/{_total_segmentator.py → metadata/_total_segmentator.py} +0 -0
  89. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/entry_points.txt +0 -0
  90. {kaiko_eva-0.2.2.dist-info → kaiko_eva-0.3.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/segmentation/msd_task7_pancreas.py
@@ -0,0 +1,243 @@
+ """Dataset for Task 7 (pancreas tumor) from the Medical Segmentation Decathlon (MSD)."""
+
+ import glob
+ import os
+ import re
+ from typing import Any, Callable, Dict, List, Literal, Tuple
+
+ import huggingface_hub
+ from torchvision import tv_tensors
+ from torchvision.datasets import utils as data_utils
+ from typing_extensions import override
+
+ from eva.core.utils import requirements
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+ from eva.vision.data.datasets.segmentation import _utils
+ from eva.vision.data.datasets.segmentation.metadata import _msd_task7_pancreas
+ from eva.vision.data.datasets.vision import VisionDataset
+
+
+ class MSDTask7Pancreas(VisionDataset[eva_tv_tensors.Volume, tv_tensors.Mask]):
+     """Task 7 (pancreas tumor) of the Medical Segmentation Decathlon (MSD).
+
+     The dataset consists of 420 portal-venous phase CT scans of patients undergoing
+     resection of pancreatic masses. The corresponding target ROIs are the pancreatic
+     parenchyma and the pancreatic mass (cyst or tumor). This dataset was selected due
+     to the label imbalance between large (background), medium (pancreas) and small
+     (tumor) structures. The data was acquired at the Memorial Sloan Kettering Cancer
+     Center, New York, US.
+
+     More info:
+       - Paper: https://www.nature.com/articles/s41467-022-30695-9
+       - Dataset source: https://huggingface.co/datasets/Luffy503/VoCo_Downstream
+     """
+
+     _train_ids = _msd_task7_pancreas.train_ids
+     """File indices of the training set."""
+
+     _val_ids = _msd_task7_pancreas.val_ids
+     """File indices of the validation set."""
+
+     def __init__(
+         self,
+         root: str,
+         split: Literal["train", "val"] | None = None,
+         download: bool = False,
+         transforms: Callable | None = None,
+     ) -> None:
+         """Initializes the dataset.
+
+         Args:
+             root: Path to the dataset root directory.
+             split: Dataset split to use ('train' or 'val').
+                 If None, it uses the full dataset.
+             download: Whether to download the dataset.
+             transforms: A callable object for applying data transformations.
+                 If None, no transformations are applied.
+         """
+         super().__init__()
+
+         self._root = root
+         self._split = split
+         self._download = download
+         self._transforms = transforms
+
+         self._samples: Dict[int, Tuple[str, str]]
+         self._indices: List[int]
+
+     @property
+     @override
+     def classes(self) -> List[str]:
+         return [
+             "background",
+             "pancreas",
+             "cancer",
+         ]
+
+     @property
+     @override
+     def class_to_idx(self) -> Dict[str, int]:
+         return {label: index for index, label in enumerate(self.classes)}
+
+     @override
+     def filename(self, index: int) -> str:
+         return os.path.relpath(self._samples[self._indices[index]][0], self._root)
+
+     @override
+     def prepare_data(self) -> None:
+         if self._download:
+             self._download_dataset()
+
+     @override
+     def configure(self) -> None:
+         self._samples = self._find_samples()
+         self._indices = self._make_indices()
+
+     @override
+     def validate(self) -> None:
+         requirements.check_dependencies(requirements={"torch": "2.5.1", "torchvision": "0.20.1"})
+
+         def _valid_sample(index: int) -> bool:
+             """Indicates if the sample files exist and are reachable."""
+             volume_file, segmentation_file = self._samples[self._indices[index]]
+             return os.path.isfile(volume_file) and os.path.isfile(segmentation_file)
+
+         if len(self._samples) < len(self._indices):
+             raise OSError(f"Dataset is missing {len(self._indices) - len(self._samples)} files.")
+
+         invalid_samples = [self._samples[i] for i in range(len(self)) if not _valid_sample(i)]
+         if invalid_samples:
+             raise OSError(
+                 f"Dataset '{self.__class__.__qualname__}' contains missing or "
+                 f"corrupted samples ({len(invalid_samples)} in total). "
+                 f"Examples of missing files: {str(invalid_samples[:10])[:-1]}, ...]. "
+             )
+
+     @override
+     def __getitem__(
+         self, index: int
+     ) -> tuple[eva_tv_tensors.Volume, tv_tensors.Mask, dict[str, Any]]:
+         volume = self.load_data(index)
+         mask = self.load_target(index)
+         metadata = self.load_metadata(index) or {}
+         volume_tensor, mask_tensor = self._apply_transforms(volume, mask)
+         return volume_tensor, mask_tensor, metadata
+
+     @override
+     def __len__(self) -> int:
+         return len(self._indices)
+
+     @override
+     def load_data(self, index: int) -> eva_tv_tensors.Volume:
+         """Loads the CT volume for a given sample.
+
+         Args:
+             index: The index of the desired sample.
+
+         Returns:
+             Tensor representing the CT volume of shape `[T, C, H, W]`.
+         """
+         ct_scan_file, _ = self._samples[self._indices[index]]
+         return _utils.load_volume_tensor(ct_scan_file)
+
+     @override
+     def load_target(self, index: int) -> tv_tensors.Mask:
+         """Loads the segmentation mask for a given sample.
+
+         Args:
+             index: The index of the desired sample.
+
+         Returns:
+             Tensor representing the segmentation mask of shape `[T, C, H, W]`.
+         """
+         ct_scan_file, mask_file = self._samples[self._indices[index]]
+         return _utils.load_mask_tensor(mask_file, ct_scan_file)
+
+     def _apply_transforms(
+         self, ct_scan: eva_tv_tensors.Volume, mask: tv_tensors.Mask
+     ) -> Tuple[eva_tv_tensors.Volume, tv_tensors.Mask]:
+         """Applies transformations to the provided data.
+
+         Args:
+             ct_scan: The CT volume tensor.
+             mask: The segmentation mask tensor.
+
+         Returns:
+             A tuple containing the transformed CT and mask tensors.
+         """
+         return self._transforms(ct_scan, mask) if self._transforms else (ct_scan, mask)
+
+     def _find_samples(self) -> Dict[int, Tuple[str, str]]:
+         """Retrieves the file paths for the CT volumes and segmentations.
+
+         Returns:
+             A dictionary mapping the file id to the volume and segmentation paths.
+         """
+
+         def filename_id_volume(filename: str) -> int:
+             matches = re.match(r".*(\d{3})_\d{4}.*", filename)
+             if matches is None:
+                 raise ValueError(f"Filename '{filename}' is not valid.")
+             return int(matches.group(1))
+
+         def filename_id_segmentation(filename: str) -> int:
+             matches = re.match(r".*(\d{3}).*", filename)
+             if matches is None:
+                 raise ValueError(f"Filename '{filename}' is not valid.")
+             return int(matches.group(1))
+
+         optional_subdir = os.path.join(self._root, "Dataset007_Pancreas")
+         search_dir = optional_subdir if os.path.isdir(optional_subdir) else self._root
+
+         volume_files_pattern = os.path.join(search_dir, "imagesTr", "*.nii.gz")
+         volume_filenames = glob.glob(volume_files_pattern)
+         volume_ids = {filename_id_volume(filename): filename for filename in volume_filenames}
+
+         segmentation_files_pattern = os.path.join(search_dir, "labelsTr", "*.nii.gz")
+         segmentation_filenames = glob.glob(segmentation_files_pattern)
+         segmentation_ids = {
+             filename_id_segmentation(filename): filename for filename in segmentation_filenames
+         }
+
+         return {
+             file_id: (volume_ids[file_id], segmentation_ids[file_id])
+             for file_id in volume_ids.keys()
+         }
+
+     def _make_indices(self) -> List[int]:
+         """Builds the dataset indices for the specified split."""
+         file_ids = []
+         match self._split:
+             case "train":
+                 file_ids = self._train_ids
+             case "val":
+                 file_ids = self._val_ids
+             case None:
+                 file_ids = self._train_ids + self._val_ids
+             case _:
+                 raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
+
+         return file_ids
+
+     def _download_dataset(self) -> None:
+         hf_token = os.getenv("HF_TOKEN")
+         if not hf_token:
+             raise ValueError("Huggingface token required, please set the HF_TOKEN env variable.")
+
+         huggingface_hub.snapshot_download(
+             "Luffy503/VoCo_Downstream",
+             repo_type="dataset",
+             token=hf_token,
+             local_dir=self._root,
+             ignore_patterns=[".git*"],
+             allow_patterns=["**/Dataset007_Pancreas.zip"],
+         )
+
+         zip_path = os.path.join(self._root, "MSD_Decathlon/Dataset007_Pancreas.zip")
+         if not os.path.exists(zip_path):
+             raise FileNotFoundError(
+                 f"MSD_Decathlon/Dataset007_Pancreas.zip not found in {self._root}, "
+                 "something went wrong with the download."
+             )
+
+         data_utils.extract_archive(zip_path, self._root, remove_finished=True)
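
For orientation, a minimal usage sketch of the new dataset class (the import path assumes `MSDTask7Pancreas` is re-exported via the updated `eva/vision/data/datasets/__init__.py`; not part of the diff):

    import os

    from eva.vision.data.datasets import MSDTask7Pancreas  # assumed re-export

    os.environ.setdefault("HF_TOKEN", "<your-token>")  # required when download=True
    dataset = MSDTask7Pancreas(root="data/msd_task7", split="train", download=True)
    dataset.prepare_data()  # downloads and extracts Dataset007_Pancreas.zip
    dataset.configure()     # discovers volume/mask pairs and builds the split indices
    volume, mask, metadata = dataset[0]  # Volume and Mask tensors of shape [T, C, H, W]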
eva/vision/data/datasets/segmentation/total_segmentator_2d.py
@@ -18,7 +18,7 @@ from typing_extensions import override
  from eva.core.utils import io as core_io
  from eva.core.utils import multiprocessing
  from eva.vision.data.datasets import _validators, structs, vision
- from eva.vision.data.datasets.segmentation import _total_segmentator
+ from eva.vision.data.datasets.segmentation.metadata import _total_segmentator
  from eva.vision.utils import io
eva/vision/data/transforms/__init__.py
@@ -1,7 +1,13 @@
  """Vision data transforms."""

- from eva.vision.data.transforms.common import ResizeAndCrop
- from eva.vision.data.transforms.croppad import CropForeground, RandCropByPosNegLabel, SpatialPad
+ from eva.vision.data.transforms.common import ResizeAndCrop, Squeeze
+ from eva.vision.data.transforms.croppad import (
+     CropForeground,
+     RandCropByLabelClasses,
+     RandCropByPosNegLabel,
+     RandSpatialCrop,
+     SpatialPad,
+ )
  from eva.vision.data.transforms.intensity import (
      RandScaleIntensity,
      RandShiftIntensity,
@@ -12,7 +18,9 @@ from eva.vision.data.transforms.utility import EnsureChannelFirst

  __all__ = [
      "ResizeAndCrop",
+     "Squeeze",
      "CropForeground",
+     "RandCropByLabelClasses",
      "RandCropByPosNegLabel",
      "SpatialPad",
      "RandScaleIntensity",
@@ -20,6 +28,7 @@ __all__ = [
      "ScaleIntensityRange",
      "RandFlip",
      "RandRotate90",
+     "RandSpatialCrop",
      "Spacing",
      "EnsureChannelFirst",
  ]
eva/vision/data/transforms/base/__init__.py
@@ -0,0 +1,5 @@
+ """Base classes for transforms."""
+
+ from eva.vision.data.transforms.base.monai import RandomMonaiTransform
+
+ __all__ = ["RandomMonaiTransform"]
eva/vision/data/transforms/base/monai.py
@@ -0,0 +1,27 @@
+ """Base class for MONAI transform wrappers."""
+
+ import abc
+
+ from torchvision.transforms import v2
+
+
+ class RandomMonaiTransform(v2.Transform, abc.ABC):
+     """Base class for MONAI transform wrappers."""
+
+     @abc.abstractmethod
+     def set_random_state(self, seed: int) -> None:
+         """Sets the random state for the transform.
+
+         MONAI's random transforms use numpy.random for random number generation,
+         which is seeded once at startup by lightning's seed_everything. When torch
+         spins up the dataloader workers, however, it only reseeds torch's random
+         states and not numpy's, so all workers end up with identically seeded
+         random transforms. This yields redundant transform outputs and reduces
+         the diversity of the resulting training data. To avoid this, call this
+         method in the dataloader's worker_init_fn with a unique seed per worker.
+
+         Args:
+             seed: The seed to set for the random state of the transform.
+         """
+         pass
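
The file list above also adds `eva/vision/data/dataloaders/worker_init.py`, which presumably wires this seeding into the dataloader. A minimal sketch of such a `worker_init_fn`, under the assumption that the dataset stores its transforms in a `_transforms` attribute (hypothetical helper, not the shipped implementation):

    from torch.utils.data import get_worker_info


    def seed_worker(worker_id: int) -> None:
        """Hypothetical worker_init_fn: reseed MONAI-backed transforms per worker."""
        info = get_worker_info()
        if info is None:
            return
        transforms = getattr(info.dataset, "_transforms", None)  # assumed attribute
        candidates = getattr(transforms, "transforms", [transforms])  # Compose or single
        for transform in candidates:
            if hasattr(transform, "set_random_state"):
                transform.set_random_state(info.seed % 2**32)  # unique per-worker seed

Passing this as `worker_init_fn` to the dataloader gives every worker's random transforms a distinct numpy seed derived from torch's per-worker seed.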
eva/vision/data/transforms/common/__init__.py
@@ -1,5 +1,6 @@
  """Common vision transforms."""

  from eva.vision.data.transforms.common.resize_and_crop import ResizeAndCrop
+ from eva.vision.data.transforms.common.squeeze import Squeeze

- __all__ = ["ResizeAndCrop"]
+ __all__ = ["ResizeAndCrop", "Squeeze"]
eva/vision/data/transforms/common/squeeze.py
@@ -0,0 +1,24 @@
+ """Squeeze transform."""
+
+ from typing import Any
+
+ import torch
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+
+
+ class Squeeze(v2.Transform):
+     """Squeezes the input tensor across all or specified dimensions."""
+
+     def __init__(self, dim: int | list[int] | None = None):
+         """Initializes the transform.
+
+         Args:
+             dim: If specified, the input will be squeezed only in the specified dimensions.
+         """
+         super().__init__()
+         self._dim = dim
+
+     def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
+         output = torch.squeeze(inpt) if self._dim is None else torch.squeeze(inpt, dim=self._dim)
+         return tv_tensors.wrap(output, like=inpt)
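
A quick behavioral sketch of the transform (illustrative only):

    import torch
    from torchvision import tv_tensors
    from eva.vision.data.transforms import Squeeze

    mask = tv_tensors.Mask(torch.zeros(1, 1, 64, 64))
    Squeeze()(mask).shape       # torch.Size([64, 64]): all singleton dims removed
    Squeeze(dim=0)(mask).shape  # torch.Size([1, 64, 64]): only dim 0 squeezed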
eva/vision/data/transforms/croppad/__init__.py
@@ -1,11 +1,15 @@
  """Transforms for crop and pad operations."""

  from eva.vision.data.transforms.croppad.crop_foreground import CropForeground
+ from eva.vision.data.transforms.croppad.rand_crop_by_label_classes import RandCropByLabelClasses
  from eva.vision.data.transforms.croppad.rand_crop_by_pos_neg_label import RandCropByPosNegLabel
+ from eva.vision.data.transforms.croppad.rand_spatial_crop import RandSpatialCrop
  from eva.vision.data.transforms.croppad.spatial_pad import SpatialPad

  __all__ = [
      "CropForeground",
+     "RandCropByLabelClasses",
      "RandCropByPosNegLabel",
+     "RandSpatialCrop",
      "SpatialPad",
  ]
eva/vision/data/transforms/croppad/rand_crop_by_label_classes.py
@@ -0,0 +1,74 @@
+ """Crop by label classes transform."""
+
+ import functools
+ from typing import Any, Dict, List, Sequence
+
+ import torch
+ from monai.config.type_definitions import NdarrayOrTensor
+ from monai.transforms.croppad import array as monai_croppad_transforms
+ from torchvision import tv_tensors
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+ from eva.vision.data.transforms import base
+
+
+ class RandCropByLabelClasses(base.RandomMonaiTransform):
+     """Crop random fixed-sized regions with the center belonging to one of the classes.
+
+     Please refer to the `monai.transforms.croppad.RandCropByLabelClasses` docs for more details.
+     """
+
+     def __init__(
+         self,
+         spatial_size: Sequence[int] | int,
+         ratios: list[float | int] | None = None,
+         label: torch.Tensor | None = None,
+         num_classes: int | None = None,
+         num_samples: int = 1,
+         image: torch.Tensor | None = None,
+         image_threshold: float = 0.0,
+         indices: list[NdarrayOrTensor] | None = None,
+         allow_smaller: bool = False,
+         warn: bool = True,
+         max_samples_per_class: int | None = None,
+         lazy: bool = False,
+     ) -> None:
+         """Initializes the transform."""
+         super().__init__()
+
+         self._rand_crop = monai_croppad_transforms.RandCropByLabelClasses(
+             spatial_size=spatial_size,
+             ratios=ratios,
+             label=label,
+             num_classes=num_classes,
+             num_samples=num_samples,
+             image=image,
+             image_threshold=image_threshold,
+             indices=indices,
+             allow_smaller=allow_smaller,
+             warn=warn,
+             max_samples_per_class=max_samples_per_class,
+             lazy=lazy,
+         )
+
+     @override
+     def set_random_state(self, seed: int) -> None:
+         self._rand_crop.set_random_state(seed)
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
+         self._rand_crop.randomize(label=mask)
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         inpt_crops = self._rand_crop(img=inpt, randomize=False)
+         return [tv_tensors.wrap(crop, like=inpt) for crop in inpt_crops]
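
An illustrative sketch of the wrapper's behavior with a three-class mask (matching e.g. the MSD pancreas labels above; shapes and ratios are hypothetical):

    import torch
    from torchvision import tv_tensors
    from eva.vision.data.transforms import RandCropByLabelClasses

    # Oversample rarer classes: background/pancreas/tumor crop centers at 1:2:4.
    crop = RandCropByLabelClasses(
        spatial_size=[96, 96, 96], ratios=[1, 2, 4], num_classes=3, num_samples=2
    )
    volume = tv_tensors.Image(torch.rand(1, 128, 128, 128))
    mask = tv_tensors.Mask(torch.randint(0, 3, (1, 128, 128, 128)))
    volume_crops, mask_crops = crop(volume, mask)  # two lists of num_samples crops each

Since each call returns a list of crops per input, the new `collate_fn` module added under `eva/core/data/dataloaders` (see the file list above) is presumably what flattens these lists at batching time.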
eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py
@@ -7,13 +7,13 @@ import torch
  from monai.config.type_definitions import NdarrayOrTensor
  from monai.transforms.croppad import array as monai_croppad_transforms
  from torchvision import tv_tensors
- from torchvision.transforms import v2
  from typing_extensions import override

  from eva.vision.data import tv_tensors as eva_tv_tensors
+ from eva.vision.data.transforms import base


- class RandCropByPosNegLabel(v2.Transform):
+ class RandCropByPosNegLabel(base.RandomMonaiTransform):
      """Crop random fixed-sized regions with the center being a foreground or background voxel.

      It is based on the pos/neg ratio and returns a list of arrays for all the cropped images.
@@ -91,6 +91,10 @@ class RandCropByPosNegLabel(v2.Transform):
              lazy=False,
          )

+     @override
+     def set_random_state(self, seed: int) -> None:
+         self._rand_crop.set_random_state(seed)
+
      def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
          mask = next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.Mask))
          self._rand_crop.randomize(label=mask)
eva/vision/data/transforms/croppad/rand_spatial_crop.py
@@ -0,0 +1,89 @@
+ """Crop image with random size or specific size ROI."""
+
+ import functools
+ from typing import Any, Dict, List, Sequence, Tuple
+
+ from monai.transforms.croppad import array as monai_croppad_transforms
+ from torchvision import tv_tensors
+ from torchvision.transforms import v2
+ from torchvision.transforms.v2 import _utils as tv_utils
+ from typing_extensions import override
+
+ from eva.vision.data import tv_tensors as eva_tv_tensors
+
+
+ class RandSpatialCrop(v2.Transform):
+     """Crop image with random size or specific size ROI.
+
+     It can crop at a random position or at the image center, and allows
+     setting minimum and maximum sizes to limit the randomly generated ROI.
+     """
+
+     def __init__(
+         self,
+         roi_size: Sequence[int] | int,
+         max_roi_size: Sequence[int] | int | None = None,
+         random_center: bool = True,
+         random_size: bool = False,
+     ) -> None:
+         """Initializes the transform.
+
+         Args:
+             roi_size: If `random_size` is True, it specifies the minimum crop
+                 region. If `random_size` is False, it specifies the expected
+                 ROI size to crop, e.g. [224, 224, 128]. If a dimension of the
+                 ROI size is larger than the image size, that dimension will
+                 not be cropped. If a component is non-positive, the
+                 corresponding size of the input image is used: for example,
+                 if the spatial size of the input data is [40, 40, 40] and
+                 `roi_size=[32, 64, -1]`, the spatial size of the output data
+                 will be [32, 40, 40].
+             max_roi_size: If `random_size` is True and `roi_size` specifies
+                 the minimum crop region size, `max_roi_size` can specify the
+                 maximum crop region size. If None, it defaults to the input
+                 image size. Non-positive components are replaced by the
+                 corresponding size of the input image.
+             random_center: Crop at a random position as center or at the image center.
+             random_size: Crop with random size or specific size ROI. If True, the
+                 actual size is sampled from `randint(roi_size, max_roi_size + 1)`.
+         """
+         super().__init__()
+
+         self._rand_spatial_crop = monai_croppad_transforms.RandSpatialCrop(
+             roi_size=roi_size,
+             max_roi_size=max_roi_size,
+             random_center=random_center,
+             random_size=random_size,
+         )
+         self._cropper = monai_croppad_transforms.Crop()
+         self._spatial_size: Tuple[int, int, int] | None = None
+
+     def set_random_state(self, seed: int) -> None:
+         """Set the random state for the transform."""
+         self._rand_spatial_crop.set_random_state(seed)
+
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         t, h, w = tv_utils.query_chw(flat_inputs)
+         self._spatial_size = (t, h, w)
+         self._rand_spatial_crop.randomize((t, h, w))
+         return {}
+
+     @functools.singledispatchmethod
+     @override
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         return inpt
+
+     @_transform.register(tv_tensors.Image)
+     @_transform.register(eva_tv_tensors.Volume)
+     @_transform.register(tv_tensors.Mask)
+     def _(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         slices = self._get_crop_slices()
+         inpt_rand_crop = self._cropper(inpt, slices=slices)
+         return tv_tensors.wrap(inpt_rand_crop, like=inpt)
+
+     def _get_crop_slices(self) -> Tuple[slice, ...]:
+         """Returns the sequence of slices to crop."""
+         if self._rand_spatial_crop.random_center:
+             return self._rand_spatial_crop._slices
+
+         # Center crop with the randomized ROI size, computed over the input spatial size.
+         central_cropper = monai_croppad_transforms.CenterSpatialCrop(self._rand_spatial_crop._size)
+         return central_cropper.compute_slices(self._spatial_size)  # type: ignore
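
A short sketch of the `roi_size` semantics described in the docstring (illustrative only):

    import torch
    from torchvision import tv_tensors
    from eva.vision.data.transforms import RandSpatialCrop

    crop = RandSpatialCrop(roi_size=[32, 64, -1])  # -1 keeps that dimension's full size
    volume = tv_tensors.Image(torch.rand(1, 40, 40, 40))
    crop(volume).shape  # torch.Size([1, 32, 40, 40]): 64 exceeds 40, so no crop there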
eva/vision/data/transforms/intensity/rand_scale_intensity.py
@@ -7,13 +7,13 @@ import numpy as np
  from monai.config.type_definitions import DtypeLike
  from monai.transforms.intensity import array as monai_intensity_transforms
  from torchvision import tv_tensors
- from torchvision.transforms import v2
  from typing_extensions import override

  from eva.vision.data import tv_tensors as eva_tv_tensors
+ from eva.vision.data.transforms import base


- class RandScaleIntensity(v2.Transform):
+ class RandScaleIntensity(base.RandomMonaiTransform):
      """Randomly scale the intensity of input image.

      The intensity is scaled by ``v = v * (1 + factor)``, where
@@ -47,6 +47,10 @@ class RandScaleIntensity(v2.Transform):
              dtype=dtype,
          )

+     @override
+     def set_random_state(self, seed: int) -> None:
+         self._rand_scale_intensity.set_random_state(seed)
+
      @functools.singledispatchmethod
      @override
      def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
eva/vision/data/transforms/intensity/rand_shift_intensity.py
@@ -5,13 +5,13 @@ from typing import Any, Dict

  from monai.transforms.intensity import array as monai_intensity_transforms
  from torchvision import tv_tensors
- from torchvision.transforms import v2
  from typing_extensions import override

  from eva.vision.data import tv_tensors as eva_tv_tensors
+ from eva.vision.data.transforms import base


- class RandShiftIntensity(v2.Transform):
+ class RandShiftIntensity(base.RandomMonaiTransform):
      """Randomly shift intensity with randomly picked offset."""

      def __init__(
@@ -36,13 +36,17 @@ class RandShiftIntensity(v2.Transform):
          """
          super().__init__()

-         self._rand_swift_intensity = monai_intensity_transforms.RandShiftIntensity(
+         self._rand_shift_intensity = monai_intensity_transforms.RandShiftIntensity(
              offsets=offsets,
              safe=safe,
              prob=prob,
              channel_wise=channel_wise,
          )

+     @override
+     def set_random_state(self, seed: int) -> None:
+         self._rand_shift_intensity.set_random_state(seed)
+
      @functools.singledispatchmethod
      @override
      def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -51,5 +55,5 @@ class RandShiftIntensity(v2.Transform):
      @_transform.register(tv_tensors.Image)
      @_transform.register(eva_tv_tensors.Volume)
      def _(self, inpt: tv_tensors.Image, params: Dict[str, Any]) -> Any:
-         inpt_scaled = self._rand_swift_intensity(inpt)
+         inpt_scaled = self._rand_shift_intensity(inpt)
          return tv_tensors.wrap(inpt_scaled, like=inpt)