batchgeneratorsv2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. batchgeneratorsv2/benchmarks/__init__.py +0 -0
  2. batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  3. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  4. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  5. batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  6. batchgeneratorsv2/dataloading/__init__.py +0 -0
  7. batchgeneratorsv2/helpers/__init__.py +0 -0
  8. batchgeneratorsv2/helpers/fft_conv.py +149 -0
  9. batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
  13. batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  15. batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
  16. batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  17. batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  18. batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  19. batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  20. batchgeneratorsv2/transforms/local/__init__.py +0 -0
  21. batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  22. batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  23. batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  24. batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  25. batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  26. batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
  28. batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  29. batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  30. batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  32. batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  33. batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  34. batchgeneratorsv2/transforms/noise/rician.py +61 -0
  35. batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  36. batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  37. batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  38. batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
  39. batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  40. batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  41. batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
  42. batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  43. batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  44. batchgeneratorsv2/transforms/utils/compose.py +89 -0
  45. batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  46. batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  47. batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  48. batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
  49. batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
  50. batchgeneratorsv2/transforms/utils/random.py +46 -0
  51. batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
  52. batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  53. batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
  54. batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
  55. batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
  56. batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
  57. batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,601 @@
1
+ import math
2
+ from copy import deepcopy
3
+ from typing import Tuple, List, Union
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from scipy.ndimage import fourier_gaussian
9
+ from torch.nn.functional import grid_sample
10
+
11
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
12
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
13
+ from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
14
+
15
+
16
class SpatialTransform(BasicTransform):
    """Random spatial augmentation: elastic deformation, rotation, scaling and cropping to ``patch_size``.

    All sampled transformations are folded into a single ``grid_sample`` grid (built in ``get_parameters``)
    that is shared by image, segmentation and regression target. When neither an affine transform nor an
    elastic deformation is drawn, the patch is cropped directly via ``crop_tensor`` without interpolation.
    """

    def __init__(self,
                 patch_size: Tuple[int, ...],
                 patch_center_dist_from_border: Union[int, List[int], Tuple[int, ...]],
                 random_crop: bool,
                 p_elastic_deform: float = 0,
                 elastic_deform_scale: RandomScalar = (0, 0.2),
                 elastic_deform_magnitude: RandomScalar = (0, 0.2),
                 p_synchronize_def_scale_across_axes: float = 0,
                 p_rotation: float = 0,
                 rotation: RandomScalar = (0, 2 * np.pi),
                 p_rot_per_axis: float = 1,
                 p_scaling: float = 0,
                 scaling: RandomScalar = (0.7, 1.3),
                 p_synchronize_scaling_across_axes: float = 0,
                 bg_style_seg_sampling: bool = True,
                 mode_seg: str = 'bilinear',
                 border_mode_seg: str = "zeros",
                 center_deformation: bool = True,
                 mode_image: str = 'bilinear',
                 padding_mode_image: str = "zeros",
                 padding_value_seg: float = 0,
                 padding_value_image: float = 0,
                 align_corners: bool = False
                 ):
        """
        Magnitude must be given in pixels!
        Deformation scale is given as a percentage of the edge length.

        Args:
            patch_size: spatial size of the output patch.
            patch_center_dist_from_border: with ``random_crop``, the patch center along axis d is drawn
                uniformly from [dist[d], shape[d] - dist[d]]; if that interval is empty the image center is
                used. A scalar is broadcast to all axes.
            random_crop: if False the patch is always centered in the image.
            p_elastic_deform: probability of applying an elastic deformation.
            elastic_deform_scale: sigma for blurring the random offsets, as a fraction of the patch edge
                length. Larger values mean coarser deformation.
            elastic_deform_magnitude: maximum displacement, measured in pixels.
            p_synchronize_def_scale_across_axes: probability that one deformation scale is used for all axes.
            p_rotation: probability of applying a rotation.
            rotation: rotation angle(s) in radians. In 2D only the angle sampled for the last axis is used.
            p_rot_per_axis: if a rotation is applied, each axis' angle is set to 0 with probability
                ``1 - p_rot_per_axis``.
            p_scaling: probability of applying scaling.
            scaling: scale factor(s). Larger numbers = smaller objects!
            p_synchronize_scaling_across_axes: probability that one scale factor is used for all axes.
            bg_style_seg_sampling: only used when ``mode_seg != 'nearest'``. If True, every label is resampled
                as a binary mask and assigned where the result is >= 0.5. If False, labels whose resampled
                mask exceeds 0.7 are assigned directly and remaining voxels take the label with the highest
                resampled value.
            mode_seg: grid_sample interpolation mode for segmentations.
            border_mode_seg: can be "zeros", "constant", "reflection", "border". ``padding_value_seg`` is only
                considered for "constant".
            center_deformation: if True and an elastic deformation is applied, the mean of the deformed grid
                (not its center position) is moved to the sampled patch center.
            mode_image: grid_sample interpolation mode for images and regression targets.
            padding_mode_image: see torch grid_sample documentation. This applies to image and regression
                target because both call ``self._apply_to_image``. Can be "zeros", "constant", "reflection",
                "border". ``padding_value_image`` is only considered for "constant".
            padding_value_seg: fill value for ``border_mode_seg == "constant"``.
            padding_value_image: fill value for ``padding_mode_image == "constant"``.
            align_corners: forwarded to grid_sample.
        """
        super().__init__()
        self.patch_size = patch_size
        if not isinstance(patch_center_dist_from_border, (tuple, list)):
            patch_center_dist_from_border = [patch_center_dist_from_border] * len(patch_size)
        self.patch_center_dist_from_border = patch_center_dist_from_border
        self.random_crop = random_crop
        self.p_elastic_deform = p_elastic_deform
        # sigma for blurring offsets, in % of patch size. Larger values mean coarser deformation
        self.elastic_deform_scale = elastic_deform_scale
        # determines the maximum displacement, measured in pixels!!
        self.elastic_deform_magnitude = elastic_deform_magnitude
        self.p_rotation = p_rotation
        self.rotation = rotation
        self.p_rot_per_axis = p_rot_per_axis
        self.p_scaling = p_scaling
        self.scaling = scaling  # larger numbers = smaller objects!
        self.p_synchronize_scaling_across_axes = p_synchronize_scaling_across_axes
        self.p_synchronize_def_scale_across_axes = p_synchronize_def_scale_across_axes
        self.bg_style_seg_sampling = bg_style_seg_sampling
        self.mode_seg = mode_seg
        self.border_mode_seg = border_mode_seg
        self.center_deformation = center_deformation
        self.mode_image = mode_image
        self.padding_mode_image = padding_mode_image
        self.padding_value_seg = padding_value_seg
        self.padding_value_image = padding_value_image
        self.align_corners = align_corners
        self._grid_cache = {}  # key: tuple(patch_size) -> float32 centered identity grid (pixel units)

    def _get_base_grid_clone(self) -> torch.Tensor:
        """Return a fresh copy of the cached centered identity grid for ``self.patch_size``."""
        key = tuple(self.patch_size)
        g = self._grid_cache.get(key)
        if g is None:
            g = _create_centered_identity_grid2(self.patch_size).float().contiguous()
            self._grid_cache[key] = g
        # callers modify the grid in place, so the cached tensor must never be handed out directly
        return g.clone()

    @staticmethod
    def _get_crop_pad_settings(padding_mode: str, padding_value: float):
        """Translate a grid_sample-style padding mode into ``(pad_mode, pad_kwargs)`` for ``crop_tensor``."""
        if padding_mode == 'reflection':
            return 'reflect', {}
        if padding_mode == 'border':
            return 'replicate', {}
        if padding_mode == 'zeros':
            return 'constant', {'value': 0}
        if padding_mode == 'constant':
            return 'constant', {'value': padding_value}
        raise RuntimeError(f'Unknown pad mode: {padding_mode}')

    @staticmethod
    def _get_grid_sample_padding_mode(padding_mode: str) -> str:
        """Map our padding mode onto one grid_sample accepts ("constant" is emulated via zeros + fix-up)."""
        if padding_mode in ('zeros', 'constant'):
            return 'zeros'
        if padding_mode in ('border', 'reflection'):
            return padding_mode
        raise RuntimeError(f'Unknown pad mode: {padding_mode}')

    @staticmethod
    def _requires_constant_padding_fixup(padding_mode: str, padding_value: float) -> bool:
        """True if out-of-bounds samples must be overwritten because grid_sample only pads with zeros."""
        return padding_mode == 'constant' and padding_value != 0

    def _compute_out_of_bounds_mask(self, grid: torch.Tensor, spatial_shape: Tuple[int, ...]) -> torch.Tensor:
        """Return a boolean mask (``grid.shape[:-1]``) marking samples that fall outside the input.

        ``grid`` holds coordinates in grid_sample order, i.e. reversed with respect to the tensor axes
        (``grid[..., k]`` addresses spatial axis ``-(k + 1)``), so ``spatial_shape`` is reversed before the
        per-coordinate bounds are formed. With ``align_corners=False`` a sample counts as outside once it lies
        beyond the outermost pixel centers at ``±(1 - 1/size)``.
        """
        if self.align_corners:
            lo = grid.new_tensor(-1.)
            hi = grid.new_tensor(1.)
        else:
            # reverse so that size[k] is the extent of the axis addressed by grid[..., k]
            size = grid.new_tensor(tuple(spatial_shape)[::-1])
            lo = -1 + 1 / size
            hi = 1 - 1 / size
        return ((grid < lo) | (grid > hi)).any(dim=-1)

    def get_parameters(self, **data_dict) -> dict:
        """Sample one spatial augmentation for ``data_dict['image']`` (laid out as (c, *spatial)).

        Returns:
            dict with
                'center_location_in_pixels': patch center in image pixel coordinates, one value per axis;
                'grid': grid_sample grid of shape (*patch_size, dim), or None if neither an affine transform
                    nor an elastic deformation was drawn (pure crop);
                'affine': (dim, dim) rotation/scaling matrix or None (kept for debugging);
                'elastic_offsets': (*patch_size, dim) displacement field or None (kept for debugging).
        """
        image = data_dict['image']
        dim = image.ndim - 1

        do_rotation = np.random.uniform() < self.p_rotation
        do_scale = np.random.uniform() < self.p_scaling
        do_deform = np.random.uniform() < self.p_elastic_deform

        if do_rotation:
            angles = [sample_scalar(self.rotation, image=image, dim=i) for i in range(0, dim)]
            if self.p_rot_per_axis < 1:
                for i in range(dim):
                    if np.random.uniform() > self.p_rot_per_axis:
                        angles[i] = 0
        else:
            angles = [0] * dim

        if do_scale:
            if np.random.uniform() <= self.p_synchronize_scaling_across_axes:
                scales = [sample_scalar(self.scaling, image=image, dim=None)] * dim
            else:
                scales = [sample_scalar(self.scaling, image=image, dim=i) for i in range(0, dim)]
        else:
            scales = [1] * dim

        # affine matrix. None lets us detect that the matmul can be skipped
        if do_scale or do_rotation:
            if dim == 3:
                affine = create_affine_matrix_3d(angles, scales)
            elif dim == 2:
                affine = create_affine_matrix_2d(angles[-1], scales)
            else:
                raise RuntimeError(f'Unsupported dimension: {dim}')
        else:
            affine = None

        # elastic deformation: build the displacement field here
        # (same method as augment_spatial_2 in batchgenerators)
        if do_deform:
            if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:
                deformation_scales = [
                    sample_scalar(self.elastic_deform_scale, image=image, dim=None,
                                  patch_size=self.patch_size)
                ] * dim
            else:
                deformation_scales = [
                    sample_scalar(self.elastic_deform_scale, image=image, dim=i,
                                  patch_size=self.patch_size)
                    for i in range(dim)
                ]

            # sigmas must be in pixels, as this will be applied to the deformation field
            sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]

            magnitude = [
                sample_scalar(self.elastic_deform_magnitude, image=image, patch_size=self.patch_size,
                              dim=i, deformation_scale=deformation_scales[i])
                for i in range(dim)]
            # (dim, *patch_size) layout keeps each component contiguous for blurring
            offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))

            # All the extra time elastic deformation costs is spent here. Gaussian blurring via numpy FFT +
            # scipy fourier_gaussian was found to be faster than torch FFT or a spatial gaussian_filter.
            for d in range(dim):
                tmp = np.fft.fftn(offsets[d].numpy())
                tmp = fourier_gaussian(tmp, sigmas[d])
                offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)

                # rescale so that the largest absolute displacement equals magnitude[d] (clipped to avoid /0)
                mx = torch.max(torch.abs(offsets[d]))
                offsets[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf))
            spatial_dims = tuple(list(range(1, dim + 1)))
            offsets = torch.permute(offsets, (*spatial_dims, 0))
        else:
            offsets = None

        shape = image.shape[1:]
        if not self.random_crop:
            center_location_in_pixels = [i / 2 for i in shape]
        else:
            center_location_in_pixels = []
            for d in range(0, dim):
                mn = self.patch_center_dist_from_border[d]
                mx = shape[d] - self.patch_center_dist_from_border[d]
                if mx < mn:
                    # axis too small to honour the border distance: fall back to the image center
                    center_location_in_pixels.append(shape[d] / 2)
                else:
                    center_location_in_pixels.append(np.random.uniform(mn, mx))

        # Precompute the deformed grid once (shared by image, segmentation, regression target)
        if affine is not None or offsets is not None:
            grid = self._get_base_grid_clone()

            # we deform first, then rotate
            if offsets is not None:
                grid += offsets
            if affine is not None:
                grid = torch.matmul(grid, torch.from_numpy(affine).float())

            # Center the grid around center_location_in_pixels. With elastic deformation we center the mean of
            # the grid, not the center position.
            if self.center_deformation and offsets is not None:
                mn = grid.mean(dim=list(range(len(shape))))
            else:
                mn = 0

            new_center = torch.Tensor([c - s / 2 for c, s in zip(center_location_in_pixels, shape)])
            grid += (new_center - mn)
            grid = _convert_my_grid_to_grid_sample_grid(grid, shape)
        else:
            grid = None

        return {
            'center_location_in_pixels': center_location_in_pixels,
            'grid': grid,
            # not needed downstream, kept for easier debugging
            'affine': affine,
            'elastic_offsets': offsets,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Resample ``img`` (c, *spatial) onto the sampled grid, or just crop it if there is no grid."""
        if params['grid'] is None:
            # No spatial transformation: crop without interpolating (saves compute).
            # crop_tensor requires the center as integer coordinates.
            pad_mode, pad_kwargs = self._get_crop_pad_settings(self.padding_mode_image, self.padding_value_image)
            return crop_tensor(
                img,
                [math.floor(i) for i in params['center_location_in_pixels']],
                self.patch_size,
                pad_mode=pad_mode,
                pad_kwargs=pad_kwargs,
            )

        grid = params['grid']
        result = grid_sample(
            img[None],
            grid[None],
            mode=self.mode_image,
            padding_mode=self._get_grid_sample_padding_mode(self.padding_mode_image),
            align_corners=self.align_corners,
        )[0]
        if self._requires_constant_padding_fixup(self.padding_mode_image, self.padding_value_image):
            out_of_bounds_mask = self._compute_out_of_bounds_mask(grid, img.shape[1:])
            result.masked_fill_(out_of_bounds_mask.unsqueeze(0), self.padding_value_image)
        return result

    def _resample_seg_map(self, values: torch.Tensor, grid: torch.Tensor, padding_mode: str) -> torch.Tensor:
        """Resample one float spatial map with ``self.mode_seg``; returns a tensor of shape ``patch_size``."""
        return grid_sample(
            values[None, None],
            grid[None],
            mode=self.mode_seg,
            padding_mode=padding_mode,
            align_corners=self.align_corners
        )[0][0]

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        """Resample the label map(s) ``segmentation`` (c, *spatial) onto the grid without mixing label values.

        Label values are never interpolated directly (except in 'nearest' mode): each label is resampled as a
        mask and assigned according to ``bg_style_seg_sampling`` (see ``__init__``).
        """
        segmentation = segmentation.contiguous()
        if params['grid'] is None:
            # No spatial transformation: crop without interpolating (saves compute).
            # crop_tensor requires the center as integer coordinates.
            pad_mode, pad_kwargs = self._get_crop_pad_settings(self.border_mode_seg, self.padding_value_seg)
            return crop_tensor(
                segmentation,
                [math.floor(i) for i in params['center_location_in_pixels']],
                self.patch_size,
                pad_mode=pad_mode,
                pad_kwargs=pad_kwargs,
            )

        grid = params['grid']
        grid_sample_padding_mode = self._get_grid_sample_padding_mode(self.border_mode_seg)

        if self.mode_seg == 'nearest':
            result_seg = grid_sample(
                segmentation[None].float(),
                grid[None],
                mode=self.mode_seg,
                padding_mode=grid_sample_padding_mode,
                align_corners=self.align_corners
            )[0].to(segmentation.dtype)
        else:
            result_seg = torch.zeros((segmentation.shape[0], *self.patch_size), dtype=segmentation.dtype)
            if self.bg_style_seg_sampling:
                for c in range(segmentation.shape[0]):
                    labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
                    # with exactly 2 labels a single resampling decides every voxel -> saves compute
                    if len(labels) == 2:
                        out = self._resample_seg_map(
                            (segmentation[c] == labels[1]).float(), grid, grid_sample_padding_mode
                        ) >= 0.5
                        result_seg[c][out] = labels[1]
                        result_seg[c][~out] = labels[0]
                    else:
                        # later labels overwrite earlier ones where both pass the 0.5 threshold
                        for u in labels:
                            mask = self._resample_seg_map(
                                (segmentation[c] == u).float(), grid, grid_sample_padding_mode
                            ) >= 0.5
                            result_seg[c][mask] = u
            else:
                for c in range(segmentation.shape[0]):
                    labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
                    # per-label resampled masks, scaled by 1000 and kept in float16 to save memory
                    tmp = torch.zeros((len(labels), *self.patch_size), dtype=torch.float16)
                    scale_factor = 1000
                    done_mask = torch.zeros(*self.patch_size, dtype=torch.bool)
                    for i, u in enumerate(labels):
                        tmp[i] = self._resample_seg_map(
                            (segmentation[c] == u).float() * scale_factor, grid, grid_sample_padding_mode
                        )
                        mask = tmp[i] > (0.7 * scale_factor)
                        result_seg[c][mask] = u
                        done_mask = done_mask | mask
                    # voxels no label clearly claimed get the label with the highest resampled value
                    if not torch.all(done_mask):
                        result_seg[c][~done_mask] = labels[tmp[:, ~done_mask].argmax(0)]
                    del tmp

        if self._requires_constant_padding_fixup(self.border_mode_seg, self.padding_value_seg):
            out_of_bounds_mask = self._compute_out_of_bounds_mask(grid, segmentation.shape[1:])
            result_seg.masked_fill_(out_of_bounds_mask.unsqueeze(0), self.padding_value_seg)
        del grid
        return result_seg.contiguous()

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        """Regression targets are resampled exactly like images (same mode and padding settings)."""
        return self._apply_to_image(regression_target, **params)

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError
+
361
+
362
def create_affine_matrix_3d(rotation_angles, scaling_factors):
    """Return the 3x3 matrix ``Rz @ Ry @ Rx @ S``.

    Args:
        rotation_angles: sequence whose first three entries are angles in radians. Entry 0 builds ``Rx``
            (leaves axis 0 fixed), entry 1 builds ``Ry`` (leaves axis 1 fixed), entry 2 builds ``Rz``
            (leaves axis 2 fixed).
        scaling_factors: per-axis scale factors placed on the diagonal of ``S``.

    Returns:
        numpy array of shape (3, 3).
    """
    cos_x, sin_x = np.cos(rotation_angles[0]), np.sin(rotation_angles[0])
    cos_y, sin_y = np.cos(rotation_angles[1]), np.sin(rotation_angles[1])
    cos_z, sin_z = np.cos(rotation_angles[2]), np.sin(rotation_angles[2])

    rot_x = np.array([[1, 0, 0],
                      [0, cos_x, -sin_x],
                      [0, sin_x, cos_x]])
    rot_y = np.array([[cos_y, 0, sin_y],
                      [0, 1, 0],
                      [-sin_y, 0, cos_y]])
    rot_z = np.array([[cos_z, -sin_z, 0],
                      [sin_z, cos_z, 0],
                      [0, 0, 1]])

    # combined rotation first, then the (right-most) scaling
    rotation = rot_z @ rot_y @ rot_x
    return rotation @ np.diag(scaling_factors)
382
+
383
+
384
def create_affine_matrix_2d(rotation_angle, scaling_factors):
    """Return the 2x2 matrix ``R @ S``.

    Args:
        rotation_angle: rotation angle in radians.
        scaling_factors: two per-axis scale factors placed on the diagonal of ``S``.

    Returns:
        numpy array of shape (2, 2).
    """
    cos_a, sin_a = np.cos(rotation_angle), np.sin(rotation_angle)
    rotation = np.array([[cos_a, -sin_a],
                         [sin_a, cos_a]])
    return rotation @ np.diag(scaling_factors)
395
+
396
+
397
+ def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
398
+ space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
399
+ grid = torch.meshgrid(space, indexing="ij")
400
+ grid = torch.stack(grid, -1)
401
+ return grid
402
+
403
+
404
+ def _convert_my_grid_to_grid_sample_grid(my_grid: torch.Tensor, original_shape: Union[Tuple[int, ...], List[int]]):
405
+ # rescale
406
+ for d in range(len(original_shape)):
407
+ s = original_shape[d]
408
+ my_grid[..., d] /= (s / 2)
409
+ my_grid = torch.flip(my_grid, (len(my_grid.shape) - 1,))
410
+ # my_grid = my_grid.flip((len(my_grid.shape) - 1,))
411
+ return my_grid
412
+
413
+
414
if __name__ == '__main__':
    # Benchmark: median wall time of one SpatialTransform call on a real 3D volume (single thread).
    # Usage: python spatial.py [IMAGE_PATH [SEG_PATH]]
    # Without arguments the original author's BraTS example files are used. If only IMAGE_PATH is given,
    # the benchmark runs without a segmentation.
    import sys
    from time import time

    import SimpleITK as sitk

    torch.set_num_threads(1)

    default_img_path = '/media/isensee/raw_data/nnUNet_raw/Dataset226_BraTS2024-BraTS-GLI/imagesTr/BraTS-GLI-00005-100_0001.nii.gz'
    default_seg_path = '/media/isensee/raw_data/nnUNet_raw/Dataset226_BraTS2024-BraTS-GLI/labelsTr/BraTS-GLI-00005-100.nii.gz'
    if len(sys.argv) > 1:
        img_path = sys.argv[1]
        seg_path = sys.argv[2] if len(sys.argv) > 2 else None
    else:
        img_path, seg_path = default_img_path, default_seg_path

    img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    seg = sitk.GetArrayFromImage(sitk.ReadImage(seg_path)) if seg_path is not None else None

    patch_size = (192, 192, 192)
    sp = SpatialTransform(
        patch_size=patch_size,
        patch_center_dist_from_border=[i / 2 for i in patch_size],
        random_crop=True,
        p_elastic_deform=0,
        elastic_deform_magnitude=(0.1, 0.1),
        elastic_deform_scale=(0.1, 0.1),
        p_synchronize_def_scale_across_axes=0.5,
        p_rotation=1,
        # +-30 degrees in radians (the former -30 / 360 * np.pi was only +-15 degrees)
        rotation=(np.deg2rad(-30), np.deg2rad(30)),
        p_scaling=1,
        scaling=(0.75, 1),
        p_synchronize_scaling_across_axes=0.5,
        bg_style_seg_sampling=True,
        mode_seg='bilinear'
    )

    times = []
    for _ in range(10):
        # fresh copies each iteration so every call sees untouched input
        data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
        if seg is not None:
            data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
        st = time()
        sp(**data_dict)
        times.append(time() - st)
    print(np.median(times))
470
+
471
+
472
+ #################
473
+ # with this part we can qualitatively test that the correct axes are ebing augmented. Just set one of the probs to 1 and off you go
474
+ #################
475
+
476
+ # def eldef_scale(image, dim, patch_size):
477
+ # return 0.1
478
+ #
479
+ # def eldef_magnitude(image, dim, patch_size, deformation_scale):
480
+ # return 10 if dim == 2 else 0
481
+ #
482
+ # def rot(image, dim):
483
+ # return 45/360 * 2 * np.pi if dim == 0 else 0
484
+ #
485
+ # def scaling(image, dim):
486
+ # return 0.5 if dim == 0 else 1
487
+ #
488
+ # # lines
489
+ # patch = torch.zeros((1, 64, 60, 68))
490
+ # patch[:, :, 10, 30] = 1
491
+ # patch[:, 50, :, 30] = 1
492
+ # patch[:, 40, 20, :] = 1
493
+ #
494
+ # # patch_block
495
+ # patch_block = torch.zeros((1, 64, 60, 68))
496
+ # patch_block[:, 22:42, 20:40, 24:44] = 1
497
+ #
498
+ # patch_line = torch.zeros((1, 64, 60, 128))
499
+ # patch_line[:, 22:24, 30:32, 10:-10] = 1
500
+ # use = patch_line
501
+ #
502
+ # sp = SpatialTransform(
503
+ # patch_size=patch.shape[1:],
504
+ # patch_center_dist_from_border=0,
505
+ # random_crop=False,
506
+ # p_elastic_deform=0,
507
+ # p_rotation=1,
508
+ # p_scaling=0,
509
+ # elastic_deform_scale=eldef_scale,
510
+ # elastic_deform_magnitude=eldef_magnitude,
511
+ # p_synchronize_def_scale_across_axes=0,
512
+ # rotation=rot,
513
+ # scaling=scaling,
514
+ # p_synchronize_scaling_across_axes=0,
515
+ # bg_style_seg_sampling=False,
516
+ # mode_seg='bilinear'
517
+ # )
518
+ #
519
+ #
520
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(use[0].numpy()), 'orig.nii.gz')
521
+ #
522
+ # params = sp.get_parameters(image=use)
523
+ # transformed = sp._apply_to_image(use, **params)
524
+ #
525
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
526
+
527
+ # p = torch.zeros((1, 1, 8, 16, 32))
528
+ # p[:, :, 2:6, 10:16, 10:24] = 1
529
+ # grid = _create_identity_grid(p.shape[2:])
530
+ # grid[:, :, :, 0] *= 0.5
531
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
532
+ # torch.all(out == p)
533
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
534
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
535
+
536
+ #################
537
+ # with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
538
+ #################
539
+
540
+ # sp = SpatialTransform(
541
+ # patch_size=(48, 52, 54),
542
+ # patch_center_dist_from_border=0,
543
+ # random_crop=True,
544
+ # p_elastic_deform=0,
545
+ # p_rotation=1,
546
+ # p_scaling=0,
547
+ # rotation=0
548
+ # )
549
+ # sp2 = SpatialTransform(
550
+ # patch_size=(48, 52, 54),
551
+ # patch_center_dist_from_border=0,
552
+ # random_crop=True,
553
+ # p_elastic_deform=0,
554
+ # p_rotation=0,
555
+ # p_scaling=0,
556
+ # )
557
+ #
558
+ # patch = torch.zeros((1, 64, 60, 68))
559
+ # patch[:, :, 10, 30] = 1
560
+ # patch[:, 50, :, 30] = 1
561
+ # patch[:, 40, 20, :] = 1
562
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
563
+ #
564
+ # center_coords = [50, 10, 16]
565
+ # params = sp.get_parameters(image=patch)
566
+ # params['center_location_in_pixels'] = center_coords
567
+ # params2 = sp2.get_parameters(image=patch)
568
+ # params2['center_location_in_pixels'] = center_coords
569
+ # transformed = sp._apply_to_image(patch, **params)
570
+ # transformed2 = sp._apply_to_image(patch, **params2)
571
+ #
572
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
573
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
574
+
575
+
576
+
577
+ ####################
578
+ # This is exploraroty code to check how to retrieve coordinates. I used it to verify that grid_sample does in fact
579
+ # use coordinates in reversed dimension order (zyx and not xyz)
580
+ ####################
581
+ # # create a dummy input which has a unique shape in each exis
582
+ # p = torch.zeros((1, 1, 8, 16, 32))
583
+ # # set one pixel to 1
584
+ # p[:, :, 4, 0, 31] = 1
585
+ # # now create an identity grid. I have verified that this grid yields the same image as the input when used in grid_sample. So the grid is correct
586
+ # grid = _create_identity_grid((8, 16, 32)).contiguous() # grid is shape torch.Size([8, 16, 32, 3])
587
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
588
+ # assert torch.all(out == p) # this passes
589
+ # # reduce the grid to the location we are interested in. That are the coordinates where we placed the 1. The 4:5 etc is only so that we keep the number of dimensions
590
+ # grid = grid[4:5, 0:1, 31:32]
591
+ # # What coordinate would we expect? Note that grid is [-1, 1]
592
+ # # For the first dimension, coordinate 4 out of shape 8 is approximately in the middle, so about 0
593
+ # # For the second dimension, coordinate 0 out of shape 16 is very low, so we expect -1 ish (remember there is aligned corners and shit)
594
+ # # For the third dimension, coordinate 31 out of shape 32 is very high, so we expect 1 ish (remember there is aligned corners and shit)
595
+ # # So we expect [0, -1, 1]
596
+ # # What do we get?
597
+ # print(grid)
598
+ # # > tensor([[[[ 0.9688, -0.9375, 0.1250]]]])
599
+ # # not what we expect
600
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
601
+ # assert out.item() == 1
@@ -0,0 +1,67 @@
1
+ from typing import Set
2
+ import numpy as np
3
+ import torch
4
+
5
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
6
+
7
+
8
class TransposeAxesTransform(BasicTransform):
    """
    Randomly permute a subset of spatial axes of image, segmentation and regression target.

    Attributes:
        allowed_axes (Set[int]): Spatial axis indices (channel axis excluded) that may be permuted among each
            other, e.g. {1, 2} for the y and z axes of an image of shape (c, x, y, z). All of them must have
            the same length.
    """

    def __init__(self, allowed_axes: Set[int]):
        """
        Store the spatial axes that take part in the permutation.

        Args:
            allowed_axes (Set[int]): Set of spatial axis indices for permutation.
        """
        super().__init__()
        self.allowed_axes = allowed_axes

    def get_parameters(self, **data_dict) -> dict:
        """
        Draw a random permutation of the allowed axes.

        Args:
            data_dict (dict): Dictionary containing the `image` tensor, laid out as (c, *spatial).

        Returns:
            dict: {'axis_order': list of int}, an ordering of all tensor axes (channel axis included) for
            `Tensor.permute`. Axes outside `allowed_axes` keep their position. With fewer than two allowed
            axes the identity ordering is returned.

        Raises:
            ValueError: if the allowed axes do not all have the same length.
        """
        image_shape = data_dict['image'].shape
        n_dims = len(image_shape)

        sizes = [image_shape[axis + 1] for axis in self.allowed_axes]
        if len(sizes) < 2:
            # nothing to swap
            return {'axis_order': list(range(n_dims))}
        if any(size != sizes[0] for size in sizes[1:]):
            raise ValueError(f"Axis shapes are not identical: {sizes}. Cannot permute.\n"
                             f"Image shape: {image_shape}. Allowed axes: {self.allowed_axes}")

        # tensor-level indices of the allowed axes (+1 skips the channel axis), put into random order
        shuffled = [axis + 1 for axis in self.allowed_axes]
        np.random.shuffle(shuffled)

        # the slots holding allowed axes (ascending) receive the shuffled axes; all other slots stay put
        order = np.arange(n_dims)
        order[np.isin(order, shuffled)] = shuffled
        return {'axis_order': order.tolist()}

    @staticmethod
    def _permute(tensor: torch.Tensor, axis_order) -> torch.Tensor:
        # contiguous() materializes the permuted layout instead of returning a strided view
        return tensor.permute(axis_order).contiguous()

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        return self._permute(segmentation, params['axis_order'])

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return self._permute(img, params['axis_order'])

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        return self._permute(regression_target, params['axis_order'])

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
64
+
65
if __name__ == '__main__':
    # Smoke test: randomly permute spatial axes 1 and 2 (both of length 32) of a random image
    # and a constant segmentation.
    transform = TransposeAxesTransform((1, 2))
    sample = {'image': torch.rand((2, 31, 32, 32)), 'segmentation': torch.ones((1, 31, 32, 32))}
    ret = transform(**sample)
File without changes