batchgeneratorsv2 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2/helpers/fft_conv.py +149 -0
- batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
- batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
- batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
- batchgeneratorsv2/transforms/local/__init__.py +0 -0
- batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
- batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
- batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
- batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
- batchgeneratorsv2/transforms/local/local_transform.py +86 -0
- batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
- batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
- batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
- batchgeneratorsv2/transforms/noise/rician.py +61 -0
- batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
- batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
- batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
- batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
- batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
- batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
- batchgeneratorsv2/transforms/utils/random.py +46 -0
- batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
- batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
- batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
- batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
- batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
- batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Tuple
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from scipy.ndimage import fourier_gaussian
|
|
7
|
+
from torch.nn.functional import grid_sample
|
|
8
|
+
|
|
9
|
+
from batchgeneratorsv2.helpers.scalar_type import sample_scalar
|
|
10
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
11
|
+
from batchgeneratorsv2.transforms.spatial.spatial import _create_centered_identity_grid2, \
|
|
12
|
+
_convert_my_grid_to_grid_sample_grid, create_affine_matrix_2d, create_affine_matrix_3d
|
|
13
|
+
from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ChannelMisalignmentTransform(ImageOnlyTransform):
    """
    Apply channel-wise misalignment to selected image channels.

    The misalignment data augmentation was introduced in Scientific Reports (2023).
    It simulates registration errors between channels by randomly applying one or
    more of the following operations to the specified image channels:

    - squeezing/scaling (a good approximation of misalignments between the T2w
      and DWI MRI sequences)
    - rotation
    - translation via a shifted crop center

    All channels listed in ``im_channels_2_misalign`` receive the same sampled
    misalignment within one call; all other channels are left untouched.

    If you use this augmentation please cite: https://www.nature.com/articles/s41598-023-46747-z

    Parameters
    ----------
    im_channels_2_misalign : Tuple[int, ...], default=(0,)
        Image channels to which the misalignment is applied.

    squeezing_zyx : Tuple[float, ...], default=(0.1, 0, 0)
        Maximum relative scaling deviation per axis in ZYX order.
        For each spatial axis, the scale factor is sampled uniformly from [1 - s, 1 + s].

    p_squeeze : float, default=0.0
        Probability of applying squeezing/scaling.

    rotation_ax_cor_sag : Tuple[float, ...], default=(np.pi, np.pi, np.pi)
        Maximum absolute rotation angle per axis in axial/coronal/sagittal
        order. Angles are sampled uniformly from [-a, a].

    rad_or_deg : {"rad", "deg"}
        Unit of ``rotation_ax_cor_sag``. Required: any other value (including the
        default ``None``) raises ``RuntimeError``. With "rad", an angle above pi
        raises ``ValueError`` and an angle above pi/12 (15 degrees) emits a warning.

    p_rotation : float, default=0.0
        Probability of applying rotation.

    shift_zyx : Tuple[int, ...], default=(2, 32, 32)
        Maximum integer shift of the crop center per axis in ZYX order. For each
        axis, the shift is sampled uniformly from the integers in [-s, s].

    p_shift : float, default=0.0
        Probability of applying translation.

    Notes
    -----
    For 2D images the first two entries of ``squeezing_zyx``, ``shift_zyx`` and
    ``rotation_ax_cor_sag`` are sampled, and the in-plane angle comes from
    ``rotation_ax_cor_sag[1]``. NOTE(review): confirm this axis mapping is the
    intended one for 2D data.
    """

    def __init__(self,
                 im_channels_2_misalign: Tuple[int, ...] = (0,),

                 squeezing_zyx: Tuple[float, ...] = (0.1, 0, 0),
                 p_squeeze: float = 0.0,

                 rotation_ax_cor_sag: Tuple[float, ...] = (np.pi, np.pi, np.pi),
                 rad_or_deg=None,
                 p_rotation: float = 0.0,

                 shift_zyx: Tuple[int, ...] = (2, 32, 32),
                 p_shift: float = 0.0,
                 ):
        import warnings  # only needed for the rotation-magnitude sanity check below

        super().__init__()
        # an immutable default avoids sharing one list object between instances
        self.im_channels_2_misalign = im_channels_2_misalign

        self.squeezingZYX = squeezing_zyx
        self.p_squeeze = p_squeeze

        if rad_or_deg == "rad":
            # Hard limit first: the former order (raise on > pi/12 first) made this
            # check unreachable.
            if any(rot > np.pi for rot in rotation_ax_cor_sag):
                raise ValueError("The rotation is probably in deg or bigger than 180°")
            # Soft limit: only a warning. Previously `raise Warning(...)` aborted
            # construction for any rotation above 15 degrees.
            if any(rot > np.pi / 12 for rot in rotation_ax_cor_sag):
                warnings.warn("The rotation is probably too big", stacklevel=2)
            self.rotation_ax_cor_sag = rotation_ax_cor_sag
        elif rad_or_deg == "deg":
            self.rotation_ax_cor_sag = [rot / 360 * (2 * np.pi) for rot in rotation_ax_cor_sag]
        else:
            raise RuntimeError('Please define the rad_or_deg: "rad"/"deg"')
        self.p_rotation = p_rotation

        self.shiftZYX = shift_zyx
        self.p_shift = p_shift

    def get_parameters(self, **data_dict) -> dict:
        """
        Sample one misalignment that is shared by all channels in ``im_channels_2_misalign``.

        Returns
        -------
        dict
            'affine': matrix from ``create_affine_matrix_2d``/``create_affine_matrix_3d``,
                or None when neither squeezing nor rotation was drawn (lets
                ``_apply_to_image`` skip resampling).
            'elastic_offsets': always None. Elastic deformation is not sampled by this
                transform; the key is kept because ``_apply_to_image`` reads it.
            'center_location_in_pixels': per-axis crop center (image center, optionally
                shifted by a random integer).

        Raises
        ------
        RuntimeError
            If an affine is needed and the image has neither 2 nor 3 spatial dimensions.
        """
        dim = data_dict['image'].ndim - 1

        # decide which operations are applied in this call
        do_squeeze = np.random.uniform() < self.p_squeeze
        do_rotation = np.random.uniform() < self.p_rotation
        do_shift = np.random.uniform() < self.p_shift

        # Squeeze
        if do_squeeze:
            squeezes = [np.random.uniform(1 - self.squeezingZYX[i], 1 + self.squeezingZYX[i]) for i in range(dim)]
        else:
            squeezes = [1] * dim

        # Rotation
        if do_rotation:
            angles = [np.random.uniform(-self.rotation_ax_cor_sag[i], self.rotation_ax_cor_sag[i]) for i in range(dim)]
        else:
            angles = [0] * dim

        # affine matrix
        if do_squeeze or do_rotation:
            if dim == 3:
                affine = create_affine_matrix_3d(angles, squeezes)
            elif dim == 2:
                affine = create_affine_matrix_2d(angles[-1], squeezes)
            else:
                raise RuntimeError(f'Unsupported dimension: {dim}')
        else:
            affine = None  # this will allow us to detect that we can skip computations

        # Translation is realised by moving the crop center away from the image center.
        shape = data_dict['image'].shape[1:]
        if do_shift:
            center_location_in_pixels = [
                shape[i] / 2 + np.random.randint(-self.shiftZYX[i], self.shiftZYX[i] + 1)
                for i in range(dim)
            ]
        else:
            center_location_in_pixels = [i / 2 for i in shape]

        return {
            'affine': affine,
            'elastic_offsets': None,
            'center_location_in_pixels': center_location_in_pixels
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Misalign the configured channels of ``img`` in place and return it.

        Each selected channel is optionally resampled through the deformed/affine
        grid with bilinear interpolation and zero padding. It is then cropped back
        to its original shape around ``center_location_in_pixels``, with constant-0
        padding.
        """
        im_shape = img.shape[1:]
        crop_center = [math.floor(i) for i in params['center_location_in_pixels']]

        sample_grid = None
        if params['affine'] is not None or params['elastic_offsets'] is not None:
            grid = _create_centered_identity_grid2(im_shape)

            # we deform first, then rotate/scale
            if params['elastic_offsets'] is not None:
                grid += params['elastic_offsets']
            if params['affine'] is not None:
                grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())

            # With elastic deformation, center the grid's mean (not its center position)
            # on the origin. Without it the grid is already centered. The former
            # `+= torch.Tensor([0, 0, 0])` was a no-op in 3D and failed to broadcast
            # against a 2D grid.
            if params['elastic_offsets'] is not None:
                grid -= grid.mean(dim=list(range(len(im_shape))))

            # Loop-invariant, so compute it once. NOTE(review): if the helper rescales
            # its input in place, calling it per channel would compound the rescaling.
            sample_grid = _convert_my_grid_to_grid_sample_grid(grid, im_shape)[None]

        for ch in self.im_channels_2_misalign:
            if sample_grid is not None:
                img[ch, ...] = grid_sample(img[ch, ...].unsqueeze(0).unsqueeze(0), sample_grid,
                                           mode='bilinear', padding_mode="zeros", align_corners=False)[0]
            img[ch, ...] = crop_tensor(img[ch, ...].unsqueeze(0), crop_center, im_shape,
                                       pad_mode='constant', pad_kwargs={'value': 0})
        return img
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
6
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
7
|
+
from torch.nn.functional import interpolate
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SimulateLowResolutionTransform(ImageOnlyTransform):
    """
    Simulate a low-resolution acquisition by down- and re-upsampling image channels.

    Each selected channel is downsampled with 'nearest-exact' interpolation to
    ``round(shape * scale)`` per spatial axis (at least 1 voxel). It is then
    interpolated back to its original shape with (bi/tri)linear interpolation.

    Parameters
    ----------
    scale : RandomScalar
        Downsampling factor, drawn via ``sample_scalar``.
    synchronize_channels : bool
        If True, all selected channels share the same sampled scales.
    synchronize_axes : bool
        If True, a single scale is used for every spatial axis.
    ignore_axes : Tuple[int, ...] or None
        Spatial axes (channel axis excluded) whose scale is forced to 1.
    allowed_channels : Tuple[int, ...] or None
        Channels eligible for the transform; None means every channel.
    p_per_channel : float
        Probability that each eligible channel is transformed.
    """

    def __init__(self,
                 scale: RandomScalar,
                 synchronize_channels: bool,
                 synchronize_axes: bool,
                 ignore_axes: Tuple[int, ...],
                 allowed_channels: Tuple[int, ...] = None,
                 p_per_channel: float = 1):
        super().__init__()
        self.scale = scale
        self.synchronize_channels = synchronize_channels
        self.synchronize_axes = synchronize_axes
        self.ignore_axes = ignore_axes
        self.allowed_channels = allowed_channels
        self.p_per_channel = p_per_channel

        # upsampling interpolation mode keyed by number of spatial dimensions
        self.upmodes = {
            1: 'linear',
            2: 'bilinear',
            3: 'trilinear'
        }

    def get_parameters(self, **data_dict) -> dict:
        """
        Pick the channels to degrade and their per-axis scales.

        Returns a dict with 'apply_to_channel' (channel indices) and 'scales'
        (a tensor of shape (len(apply_to_channel), n_spatial_axes), or an empty
        tensor when no channel was selected).
        """
        image = data_dict['image']
        shape = image.shape
        n_spatial = len(shape) - 1

        if self.allowed_channels is None:
            apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
        else:
            apply_to_channel = [i for i in self.allowed_channels if torch.rand(1) < self.p_per_channel]

        if self.synchronize_channels:
            # one set of scales, replicated for every selected channel
            if self.synchronize_axes:
                shared = [sample_scalar(self.scale, image=image, channel=None, dim=None)] * n_spatial
            else:
                shared = [sample_scalar(self.scale, image=image, channel=None, dim=d) for d in range(n_spatial)]
            scales = torch.Tensor([shared] * len(apply_to_channel))
        else:
            if self.synchronize_axes:
                scales = torch.Tensor([[sample_scalar(self.scale, image=image, channel=c, dim=None)] * n_spatial
                                       for c in apply_to_channel])
            else:
                scales = torch.Tensor([[sample_scalar(self.scale, image=image, channel=c, dim=d)
                                        for d in range(n_spatial)] for c in apply_to_channel])

        if len(scales) > 0 and self.ignore_axes is not None:
            scales[:, self.ignore_axes] = 1
        return {
            'apply_to_channel': apply_to_channel,
            'scales': scales
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Degrade the selected channels of ``img`` in place and return it."""
        orig_shape = img.shape[1:]
        # we cannot batch this because the downsampled shapes will be different for each channel
        for c, s in zip(params['apply_to_channel'], params['scales']):
            # Clamp to >= 1: interpolate rejects zero-sized outputs, which small scales
            # on short axes would otherwise produce (e.g. round(3 * 0.1) == 0).
            new_shape = [max(1, round(i * j.item())) for i, j in zip(orig_shape, s)]
            downsampled = interpolate(img[c][None, None], new_shape, mode='nearest-exact')
            img[c] = interpolate(downsampled, orig_shape, mode=self.upmodes[img.ndim - 1])[0, 0]
        return img
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
if __name__ == '__main__':
    import os
    from time import time

    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _mean_runtime(transform, make_sample, n_runs):
        """Mean wall-clock seconds of ``transform(**make_sample())``; building the sample is not timed."""
        durations = []
        for _ in range(n_runs):
            sample = make_sample()
            start = time()
            transform(**sample)
            durations.append(time() - start)
        return np.mean(durations)

    # batchgeneratorsv2 (torch) implementation
    slrt_v2 = SimulateLowResolutionTransform((0.1, 1.), synchronize_channels=False, synchronize_axes=False,
                                             ignore_axes=None, allowed_channels=None, p_per_channel=1)
    print('torch', _mean_runtime(slrt_v2, lambda: {'image': torch.ones((3, 128, 192, 64))}, 30))

    # legacy batchgenerators implementation, imported only after the torch timing is printed
    from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform as SLRT

    slrt_v1 = SLRT((0.1, 1), True, p_per_channel=1, order_downsample=0, order_upsample=1, p_per_sample=1)
    print('bg', _mean_runtime(slrt_v1, lambda: {'data': np.ones((1, 3, 128, 192, 64))}, 30))
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MirrorTransform(BasicTransform):
    """
    Randomly flip spatial axes of the image, segmentation and regression target.

    Every axis in ``allowed_axes`` (spatial index, channel axis excluded) is chosen
    independently with probability 0.5. All targets of one sample are flipped along
    the same chosen axes. Bounding boxes and keypoints are not supported.
    """

    def __init__(self, allowed_axes: Tuple[int, ...]):
        super().__init__()
        self.allowed_axes = allowed_axes

    def get_parameters(self, **data_dict) -> dict:
        chosen_axes = []
        for axis in self.allowed_axes:
            if torch.rand(1) < 0.5:
                chosen_axes.append(axis)
        return {'axes': chosen_axes}

    @staticmethod
    def _flip_spatial(tensor: torch.Tensor, spatial_axes) -> torch.Tensor:
        # nothing drawn: return the input object itself, no copy
        if len(spatial_axes) == 0:
            return tensor
        # shift by one because dim 0 is the channel dimension
        return torch.flip(tensor, [axis + 1 for axis in spatial_axes])

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return self._flip_spatial(img, params['axes'])

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        return self._flip_spatial(segmentation, params['axes'])

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        return self._flip_spatial(regression_target, params['axes'])

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
if __name__ == '__main__':
    import os
    from time import time

    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _mean_runtime(transform, make_sample, n_runs):
        """Mean wall-clock seconds of ``transform(**make_sample())``; building the sample is not timed."""
        durations = []
        for _ in range(n_runs):
            sample = make_sample()
            start = time()
            transform(**sample)
            durations.append(time() - start)
        return np.mean(durations)

    # batchgeneratorsv2 (torch) implementation
    mirror_v2 = MirrorTransform((0, 1, 2))
    print('torch', _mean_runtime(mirror_v2, lambda: {'image': torch.ones((2, 128, 192, 64))}, 100))

    # legacy batchgenerators implementation, imported only after the torch timing is printed
    from batchgenerators.transforms.spatial_transforms import MirrorTransform as BGMirror

    mirror_v1 = BGMirror((0, 1, 2))
    print('bg', _mean_runtime(mirror_v1, lambda: {'data': np.ones((1, 2, 128, 192, 64))}, 100))
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from typing import Tuple, Set, List
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
6
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Rot90Transform(BasicTransform):
    """
    Apply random 90-degree rotations to the image, segmentation and regression target.

    Per call, ``round(sample_scalar(num_axis_combinations))`` rotations are drawn.
    Each rotation picks two distinct spatial axes from ``allowed_axes`` and a
    multiple of 90 degrees from ``num_rot_per_combination``. The rotations are
    applied in sequence, identically to every target of the sample.
    Bounding boxes and keypoints are not supported.

    Parameters
    ----------
    num_axis_combinations : RandomScalar
        Number of successive rotations to apply (rounded after sampling).
    num_rot_per_combination : Tuple[int, ...], default=(1, 2, 3)
        Candidate multiples of 90 degrees for each rotation.
    allowed_axes : set of int, default={0, 1, 2}
        Spatial axes (channel axis excluded) from which each rotation plane is
        drawn. Must contain at least two axes whenever a rotation is sampled.
    """

    def __init__(self, num_axis_combinations: RandomScalar, num_rot_per_combination: Tuple[int, ...] = (1, 2, 3),
                 allowed_axes: Set[int] = frozenset({0, 1, 2})):
        super().__init__()
        self.num_axis_combinations = num_axis_combinations
        self.num_rot_per_combination = num_rot_per_combination
        # immutable default: a shared mutable set default could leak state between instances
        self.allowed_axes = allowed_axes

    def get_parameters(self, **data_dict) -> dict:
        """
        Draw the rotations for this call.

        Returns a dict with 'num_rot_per_combination' (list of ints) and
        'axis_combinations' (list of sorted [a, b] tensor dims, channel offset included).
        """
        n_axes_combinations = round(sample_scalar(self.num_axis_combinations))
        candidate_axes = list(self.allowed_axes)
        axis_combinations = []
        num_rot_per_combination = []
        for _ in range(n_axes_combinations):
            num_rot_per_combination.append(int(np.random.choice(self.num_rot_per_combination)))
            plane = sorted(np.random.choice(candidate_axes, size=2, replace=False))
            # +1 because we skip channel dimension; int() turns numpy integers into plain ints
            axis_combinations.append([int(a) + 1 for a in plane])
        return {
            'num_rot_per_combination': num_rot_per_combination,
            'axis_combinations': axis_combinations
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return self._maybe_rot90(img, **params)

    def _apply_to_segmentation(self, seg: torch.Tensor, **params) -> torch.Tensor:
        return self._maybe_rot90(seg, **params)

    def _apply_to_regr_target(self, regression_target: torch.Tensor, **params) -> torch.Tensor:
        return self._maybe_rot90(regression_target, **params)

    def _maybe_rot90(self, tensor: torch.Tensor, num_rot_per_combination: List[int],
                     axis_combinations: List[List[int]]) -> torch.Tensor:
        """Apply each sampled (k, dims) rotation in order; no rotations returns ``tensor`` unchanged."""
        for n_rot, axes in zip(num_rot_per_combination, axis_combinations):
            tensor = torch.rot90(tensor, k=n_rot, dims=axes)
        return tensor

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
|
|
59
|
+
|
|
60
|
+
if __name__ == '__main__':
    # Small demo volume laid out as (C, X, Y, Z), plus an all-zero segmentation
    image = torch.arange(8 ** 3).reshape(1, 8, 8, 8).float()
    seg = torch.zeros_like(image)

    rot = Rot90Transform(num_axis_combinations=2, num_rot_per_combination=(1, 2, 3), allowed_axes={0, 1, 2})

    # Sample once, then apply the same parameters to both targets
    params = rot.get_parameters(image=image, segmentation=seg)
    image_rot = rot._apply_to_image(image, **params)
    seg_rot = rot._apply_to_segmentation(seg, **params)

    for label, value in (("Original image shape:", image.shape),
                         ("Rotated image shape:", image_rot.shape),
                         ("Rotation parameters:", params)):
        print(label, value)
|