batchgeneratorsv2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. batchgeneratorsv2/benchmarks/__init__.py +0 -0
  2. batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  3. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  4. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  5. batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  6. batchgeneratorsv2/dataloading/__init__.py +0 -0
  7. batchgeneratorsv2/helpers/__init__.py +0 -0
  8. batchgeneratorsv2/helpers/fft_conv.py +149 -0
  9. batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
  13. batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  15. batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
  16. batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  17. batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  18. batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  19. batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  20. batchgeneratorsv2/transforms/local/__init__.py +0 -0
  21. batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  22. batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  23. batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  24. batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  25. batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  26. batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
  28. batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  29. batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  30. batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  32. batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  33. batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  34. batchgeneratorsv2/transforms/noise/rician.py +61 -0
  35. batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  36. batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  37. batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  38. batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
  39. batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  40. batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  41. batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
  42. batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  43. batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  44. batchgeneratorsv2/transforms/utils/compose.py +89 -0
  45. batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  46. batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  47. batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  48. batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
  49. batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
  50. batchgeneratorsv2/transforms/utils/random.py +46 -0
  51. batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
  52. batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  53. batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
  54. batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
  55. batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
  56. batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
  57. batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,224 @@
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ from scipy.ndimage import fourier_gaussian
7
+ from torch.nn.functional import grid_sample
8
+
9
+ from batchgeneratorsv2.helpers.scalar_type import sample_scalar
10
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
11
+ from batchgeneratorsv2.transforms.spatial.spatial import _create_centered_identity_grid2, \
12
+ _convert_my_grid_to_grid_sample_grid, create_affine_matrix_2d, create_affine_matrix_3d
13
+ from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
14
+
15
+
16
class ChannelMisalignmentTransform(ImageOnlyTransform):
    """
    The misalignment data augmentation was introduced in Scientific Reports (2023).

    Apply channel-wise misalignment to selected image channels.
    This transform simulates registration errors between channels by randomly
    applying one or more of the following operations to the specified image
    channels:
    - squeezing/scaling (good approximation for misalignments between the T2w and DWI MRI sequences)
    - rotation
    - translation via shifted crop center

    If you use this augmentation please cite: https://www.nature.com/articles/s41598-023-46747-z

    Parameters
    ----------
    im_channels_2_misalign : Tuple[int, ...], default=(0,)
        Image channels to which the misalignment is applied.

    squeezing_zyx : Tuple[float, ...], default=(0.1, 0, 0)
        Maximum relative scaling deviation per axis in ZYX order.
        For each active axis, the scale factor is sampled uniformly from [1 - s, 1 + s].
        Only the first `dim` entries are used (`dim` = number of spatial axes).

    p_squeeze : float, default=0.0
        Probability of applying squeezing/scaling.

    rotation_ax_cor_sag : Tuple[float, ...], default=(np.pi, np.pi, np.pi)
        Maximum absolute rotation angle per axis in axial/coronal/sagittal
        order. Angles are sampled uniformly from [-a, a]. For 2D images only
        the angle sampled from the second entry is used.

    rad_or_deg : {"rad", "deg"}
        Unit of `rotation_ax_cor_sag`. Required; any other value raises a
        RuntimeError. With "rad", angles above pi raise a ValueError and
        angles above pi/12 (15 degrees) emit a UserWarning.

    p_rotation : float, default=0.0
        Probability of applying rotation.

    shift_zyx : Tuple[int, ...], default=(2, 32, 32)
        Maximum integer shift per axis in ZYX order. For each axis, the shift
        is sampled uniformly from [-s, s].

    p_shift : float, default=0.0
        Probability of applying translation.

    """

    def __init__(self,
                 im_channels_2_misalign: Tuple[int, ...] = (0,),

                 squeezing_zyx: Tuple[float, ...] = (0.1, 0, 0),
                 p_squeeze: float = 0.0,

                 rotation_ax_cor_sag: Tuple[float, ...] = (np.pi, np.pi, np.pi),
                 rad_or_deg=None,
                 p_rotation: float = 0.0,

                 shift_zyx: Tuple[int, ...] = (2, 32, 32),
                 p_shift: float = 0.0,
                 ):
        import warnings

        super().__init__()
        self.im_channels_2_misalign = im_channels_2_misalign

        self.squeezingZYX = squeezing_zyx
        self.p_squeeze = p_squeeze

        if rad_or_deg == "rad":
            # Hard limit first: anything above pi was most likely given in degrees.
            # It must be checked before the soft limit, otherwise it can never trigger.
            if any(rot > np.pi for rot in rotation_ax_cor_sag):
                raise ValueError("The rotation is probably in deg or bigger than 180°")
            # Soft limit: large but plausible angles are allowed, the user is only warned.
            if any(rot > np.pi / 12 for rot in rotation_ax_cor_sag):
                warnings.warn("The rotation is probably too big", stacklevel=2)
            self.rotation_ax_cor_sag = rotation_ax_cor_sag
        elif rad_or_deg == "deg":
            self.rotation_ax_cor_sag = [rot / 360 * (2 * np.pi) for rot in rotation_ax_cor_sag]
        else:
            raise RuntimeError('Please define the rad_or_deg: "rad"/"deg"')
        self.p_rotation = p_rotation

        self.shiftZYX = shift_zyx
        self.p_shift = p_shift

    def get_parameters(self, **data_dict) -> dict:
        """
        Sample the misalignment parameters for one sample.

        Returns a dict with
        - 'affine': matrix from create_affine_matrix_2d/3d, or None when neither
          squeezing nor rotation was drawn (lets _apply_to_image skip resampling),
        - 'elastic_offsets': always None; this transform does not sample elastic deformations,
        - 'center_location_in_pixels': crop center per spatial axis (image center, optionally shifted).
        """
        shape = data_dict['image'].shape[1:]
        dim = len(shape)

        do_squeeze = np.random.uniform() < self.p_squeeze
        do_rotation = np.random.uniform() < self.p_rotation
        do_shift = np.random.uniform() < self.p_shift

        if do_squeeze:
            squeezes = [np.random.uniform(1 - self.squeezingZYX[i], 1 + self.squeezingZYX[i]) for i in range(dim)]
        else:
            squeezes = [1] * dim

        if do_rotation:
            angles = [np.random.uniform(-self.rotation_ax_cor_sag[i], self.rotation_ax_cor_sag[i])
                      for i in range(dim)]
        else:
            angles = [0] * dim

        if do_squeeze or do_rotation:
            if dim == 3:
                affine = create_affine_matrix_3d(angles, squeezes)
            elif dim == 2:
                # the 2D helper takes a single in-plane angle
                affine = create_affine_matrix_2d(angles[-1], squeezes)
            else:
                raise RuntimeError(f'Unsupported dimension: {dim}')
        else:
            affine = None  # this will allow us to detect that we can skip computations

        if do_shift:
            center_location_in_pixels = [shape[i] / 2 + np.random.randint(-self.shiftZYX[i], self.shiftZYX[i] + 1)
                                         for i in range(dim)]
        else:
            center_location_in_pixels = [i / 2 for i in shape]

        return {
            'affine': affine,
            'elastic_offsets': None,
            'center_location_in_pixels': center_location_in_pixels
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Resample (if an affine/offset field is given) and re-crop the selected channels in place.

        Channels not listed in `im_channels_2_misalign` are left untouched.
        """
        im_shape = img.shape[1:]
        affine = params['affine']
        elastic_offsets = params['elastic_offsets']
        crop_center = [math.floor(i) for i in params['center_location_in_pixels']]

        if affine is not None or elastic_offsets is not None:
            # NOTE(review): assumes the helper returns a (*im_shape, dim) float grid centered at 0 — confirm
            grid = _create_centered_identity_grid2(im_shape)

            # we deform first, then rotate
            if elastic_offsets is not None:
                grid += elastic_offsets
            if affine is not None:
                grid = torch.matmul(grid, torch.from_numpy(affine).float())

            # With an elastic deformation the grid mean drifts; move it back to the origin.
            # Without one the grid is already centered, so nothing is added (this also keeps
            # 2D grids working, which have only two coordinates per point).
            if elastic_offsets is not None:
                grid -= grid.mean(dim=list(range(len(im_shape))))

            # the sampling grid is identical for every channel, so convert it once
            sample_grid = _convert_my_grid_to_grid_sample_grid(grid, im_shape)[None]
            for ch in self.im_channels_2_misalign:
                img[ch, ...] = grid_sample(img[ch, ...].unsqueeze(0).unsqueeze(0), sample_grid,
                                           mode='bilinear', padding_mode="zeros", align_corners=False)[0]

        # translation is realised by cropping around the (possibly shifted) center
        for ch in self.im_channels_2_misalign:
            img[ch, ...] = crop_tensor(img[ch, ...].unsqueeze(0), crop_center, im_shape,
                                       pad_mode='constant', pad_kwargs={'value': 0})
        return img
@@ -0,0 +1,92 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+
5
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
6
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
7
+ from torch.nn.functional import interpolate
8
+
9
+
10
class SimulateLowResolutionTransform(ImageOnlyTransform):
    """
    Simulate a low-resolution acquisition of selected image channels.

    Each selected channel is downsampled with 'nearest-exact' interpolation to
    `scale` times its spatial shape and then upsampled back to the original
    shape with linear/bilinear/trilinear interpolation (chosen by the number
    of spatial axes). Channels are modified in place.

    Parameters
    ----------
    scale : RandomScalar
        Relative size of the downsampled image per spatial axis, sampled via `sample_scalar`.
    synchronize_channels : bool
        If True, all selected channels share the same sampled scales.
    synchronize_axes : bool
        If True, a single scale is sampled and used for every spatial axis.
    ignore_axes : Tuple[int, ...] or None
        Spatial axes (0-based, channel axis excluded) whose scale is forced to 1.
    allowed_channels : Tuple[int, ...], optional
        Channels eligible for the transform; None means all channels.
    p_per_channel : float, default=1
        Probability that an eligible channel is transformed.
    """

    def __init__(self,
                 scale: RandomScalar,
                 synchronize_channels: bool,
                 synchronize_axes: bool,
                 ignore_axes: Tuple[int, ...],
                 allowed_channels: Tuple[int, ...] = None,
                 p_per_channel: float = 1):
        super().__init__()
        self.scale = scale
        self.synchronize_channels = synchronize_channels
        self.synchronize_axes = synchronize_axes
        self.ignore_axes = ignore_axes
        self.allowed_channels = allowed_channels
        self.p_per_channel = p_per_channel

        # number of spatial axes -> interpolate() mode used for upsampling
        self.upmodes = {
            1: 'linear',
            2: 'bilinear',
            3: 'trilinear'
        }

    def get_parameters(self, **data_dict) -> dict:
        """
        Choose the channels to transform and sample one scale per (channel, spatial axis).

        Returns a dict with 'apply_to_channel' (channel indices) and 'scales'
        (tensor of shape (len(apply_to_channel), n_spatial_axes)).
        """
        shape = data_dict['image'].shape
        n_spatial = len(shape) - 1
        if self.allowed_channels is None:
            apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
        else:
            apply_to_channel = [i for i in self.allowed_channels if torch.rand(1) < self.p_per_channel]

        if self.synchronize_channels:
            if self.synchronize_axes:
                scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=None)] * n_spatial] * len(apply_to_channel))
            else:
                scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=d) for d in range(n_spatial)]] * len(apply_to_channel))
        else:
            if self.synchronize_axes:
                scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=None)] * n_spatial for c in apply_to_channel])
            else:
                scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=d) for d in range(n_spatial)] for c in apply_to_channel])

        # the length guard avoids 2D indexing into an empty (0,) tensor
        if len(scales) > 0 and self.ignore_axes is not None:
            scales[:, self.ignore_axes] = 1
        return {
            'apply_to_channel': apply_to_channel,
            'scales': scales
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        orig_shape = img.shape[1:]
        upmode = self.upmodes[img.ndim - 1]
        # we cannot batch this because the downsampled shapes will be different for each channel
        for c, s in zip(params['apply_to_channel'], params['scales']):
            # Keep at least one voxel per axis: a small scale on a short axis would otherwise
            # round to 0 and make interpolate() fail.
            new_shape = [max(1, round(i * j.item())) for i, j in zip(orig_shape, s)]
            downsampled = interpolate(img[c][None, None], new_shape, mode='nearest-exact')
            img[c] = interpolate(downsampled, orig_shape, mode=upmode)[0, 0]
        return img
63
+
64
+
65
if __name__ == '__main__':
    import os
    from time import time

    import numpy as np

    # restrict the timings below to a single thread
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    n_repeats = 30

    # batchgeneratorsv2 (torch) implementation
    transform_v2 = SimulateLowResolutionTransform((0.1, 1.), synchronize_channels=False,
                                                  synchronize_axes=False, ignore_axes=None,
                                                  allowed_channels=None, p_per_channel=1)
    durations_v2 = []
    for _ in range(n_repeats):
        sample = {'image': torch.ones((3, 128, 192, 64))}
        t_start = time()
        result = transform_v2(**sample)
        durations_v2.append(time() - t_start)
    print('torch', np.mean(durations_v2))

    # original batchgenerators implementation, for comparison
    from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform as SLRT

    transform_v1 = SLRT((0.1, 1), True, p_per_channel=1, order_downsample=0, order_upsample=1, p_per_sample=1)
    durations_v1 = []
    for _ in range(n_repeats):
        sample = {'data': np.ones((1, 3, 128, 192, 64))}
        t_start = time()
        result = transform_v1(**sample)
        durations_v1.append(time() - t_start)
    print('bg', np.mean(durations_v1))
@@ -0,0 +1,71 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+
5
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
6
+
7
+
8
class MirrorTransform(BasicTransform):
    """
    Randomly flip image, segmentation and regression target along spatial axes.

    Each axis in `allowed_axes` (spatial, 0-based, channel axis excluded) is
    flipped independently with probability 0.5; the same flips are applied to
    every target. Bounding boxes and keypoints are not supported.
    """

    def __init__(self, allowed_axes: Tuple[int, ...]):
        super().__init__()
        self.allowed_axes = allowed_axes

    def get_parameters(self, **data_dict) -> dict:
        chosen = []
        for axis in self.allowed_axes:
            if torch.rand(1) < 0.5:
                chosen.append(axis)
        return {'axes': chosen}

    @staticmethod
    def _flip_spatial(tensor: torch.Tensor, spatial_axes) -> torch.Tensor:
        # nothing to do when no axis was drawn
        if len(spatial_axes) == 0:
            return tensor
        # +1 skips the leading channel dimension
        return torch.flip(tensor, [axis + 1 for axis in spatial_axes])

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return self._flip_spatial(img, params['axes'])

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        return self._flip_spatial(segmentation, params['axes'])

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        return self._flip_spatial(regression_target, params['axes'])

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
42
+
43
+
44
if __name__ == '__main__':
    import os
    from time import time

    import numpy as np

    # restrict the timings below to a single thread
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    n_repeats = 100

    # batchgeneratorsv2 (torch) implementation
    mirror_v2 = MirrorTransform((0, 1, 2))
    durations_v2 = []
    for _ in range(n_repeats):
        sample = {'image': torch.ones((2, 128, 192, 64))}
        t_start = time()
        result = mirror_v2(**sample)
        durations_v2.append(time() - t_start)
    print('torch', np.mean(durations_v2))

    # original batchgenerators implementation, for comparison
    from batchgenerators.transforms.spatial_transforms import MirrorTransform as BGMirror

    mirror_v1 = BGMirror((0, 1, 2))
    durations_v1 = []
    for _ in range(n_repeats):
        sample = {'data': np.ones((1, 2, 128, 192, 64))}
        t_start = time()
        result = mirror_v1(**sample)
        durations_v1.append(time() - t_start)
    print('bg', np.mean(durations_v1))
@@ -0,0 +1,78 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Tuple, Set, List
4
+
5
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
6
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
7
+
8
+
9
class Rot90Transform(BasicTransform):
    """
    Apply random 90-degree rotations to the image and its associated targets.

    For each of `num_axis_combinations` rounds, two distinct spatial axes are
    drawn from `allowed_axes`, and the data is rotated in that plane by a
    multiple of 90 degrees drawn from `num_rot_per_combination`. The same
    rotations are applied to image, segmentation and regression target.

    Args:
        num_axis_combinations (RandomScalar): Number of (axis pair, rotation)
            rounds. It is sampled with `sample_scalar` and rounded to an int.
        num_rot_per_combination (Tuple[int, ...]): Candidate multiples of 90
            degrees, e.g. (1, 2, 3).
        allowed_axes (Set[int]): Spatial axes (0-based, channel axis excluded)
            from which rotation planes are drawn. It must contain at least two
            axes whenever a rotation round is sampled, otherwise numpy raises
            a ValueError.
    """

    def __init__(self, num_axis_combinations: RandomScalar, num_rot_per_combination: Tuple[int, ...] = (1, 2, 3),
                 allowed_axes: Set[int] = frozenset({0, 1, 2})):
        super().__init__()
        self.num_axis_combinations = num_axis_combinations
        self.num_rot_per_combination = num_rot_per_combination
        self.allowed_axes = allowed_axes

    def get_parameters(self, **data_dict) -> dict:
        n_axes_combinations = round(sample_scalar(self.num_axis_combinations))
        candidate_axes = list(self.allowed_axes)
        axis_combinations = []
        num_rot_per_combination = []
        for _ in range(n_axes_combinations):
            num_rot_per_combination.append(int(np.random.choice(self.num_rot_per_combination)))
            plane = sorted(np.random.choice(candidate_axes, size=2, replace=False))
            # +1 because we skip the channel dimension; cast numpy ints to plain ints for torch.rot90
            axis_combinations.append([int(a) + 1 for a in plane])
        return {
            'num_rot_per_combination': num_rot_per_combination,
            'axis_combinations': axis_combinations
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return self._maybe_rot90(img, **params)

    def _apply_to_segmentation(self, seg: torch.Tensor, **params) -> torch.Tensor:
        return self._maybe_rot90(seg, **params)

    def _apply_to_regr_target(self, regression_target: torch.Tensor, **params) -> torch.Tensor:
        return self._maybe_rot90(regression_target, **params)

    def _maybe_rot90(self, tensor: torch.Tensor, num_rot_per_combination: List[int],
                     axis_combinations: List[List[int]]) -> torch.Tensor:
        # apply the sampled rotations in order; axes already include the channel offset
        for n_rot, axes in zip(num_rot_per_combination, axis_combinations):
            tensor = torch.rot90(tensor, k=n_rot, dims=axes)
        return tensor

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
59
+
60
if __name__ == '__main__':
    # dummy 3D image (C, X, Y, Z) with distinct voxel values, plus an all-zero segmentation
    image = torch.arange(8 ** 3).reshape(1, 8, 8, 8).float()
    seg = torch.zeros_like(image)

    # demo configuration: two (axis pair, k * 90 degree) rounds per call
    transform = Rot90Transform(num_axis_combinations=2, num_rot_per_combination=(1, 2, 3),
                               allowed_axes={0, 1, 2})

    # draw one parameter set and apply it to both image and segmentation
    params = transform.get_parameters(image=image, segmentation=seg)
    image_rot = transform._apply_to_image(image, **params)
    seg_rot = transform._apply_to_segmentation(seg, **params)

    for label, value in (("Original image shape:", image.shape),
                         ("Rotated image shape:", image_rot.shape),
                         ("Rotation parameters:", params)):
        print(label, value)