batchgeneratorsv2 0.2__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/PKG-INFO +2 -2
  2. batchgeneratorsv2-0.2.2/batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  3. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/spatial/low_resolution.py +6 -2
  4. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/spatial/spatial.py +178 -114
  5. batchgeneratorsv2-0.2.2/batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  6. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2.egg-info/PKG-INFO +2 -2
  7. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2.egg-info/SOURCES.txt +2 -0
  8. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/pyproject.toml +1 -1
  9. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/LICENSE +0 -0
  10. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/__init__.py +0 -0
  11. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/benchmarks/__init__.py +0 -0
  12. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  13. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +0 -0
  14. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +0 -0
  15. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
  16. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/dataloading/__init__.py +0 -0
  17. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/helpers/__init__.py +0 -0
  18. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
  19. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/__init__.py +0 -0
  20. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/base/__init__.py +0 -0
  21. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/base/basic_transform.py +0 -0
  22. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  23. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/intensity/brightness.py +0 -0
  24. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/intensity/contrast.py +0 -0
  25. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/intensity/gamma.py +0 -0
  26. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -0
  27. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  28. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +0 -0
  29. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
  30. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
  31. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  32. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +0 -0
  33. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  34. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
  35. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  36. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/compose.py +0 -0
  37. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
  38. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
  39. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +0 -0
  40. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/pseudo2d.py +0 -0
  41. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/random.py +0 -0
  42. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/remove_label.py +0 -0
  43. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2/transforms/utils/seg_to_regions.py +0 -0
  44. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
  45. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2.egg-info/requires.txt +0 -0
  46. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
  47. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/readme.md +0 -0
  48. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/setup.cfg +0 -0
  49. {batchgeneratorsv2-0.2 → batchgeneratorsv2-0.2.2}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: batchgeneratorsv2
3
- Version: 0.2
3
+ Version: 0.2.2
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -0,0 +1,51 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
5
+
6
+
7
+ class InvertImageTransform(ImageOnlyTransform):
8
+ def __init__(self, p_invert_image: float, p_synchronize_channels: float = 1, p_per_channel: float = 1):
9
+ super().__init__()
10
+ self.p_invert_image = p_invert_image
11
+ self.p_synchronize_channels = p_synchronize_channels
12
+ self.p_per_channel = p_per_channel
13
+
14
+ def get_parameters(self, **data_dict) -> dict:
15
+ shape = data_dict['image'].shape
16
+ apply = np.random.uniform() < self.p_invert_image
17
+ if apply:
18
+ if np.random.uniform() < self.p_synchronize_channels:
19
+ apply_to_channel = torch.arange(0, shape[0])
20
+ else:
21
+ apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
22
+ else:
23
+ apply_to_channel = []
24
+ return {
25
+ 'apply_to_channel': apply_to_channel,
26
+ 'apply': apply,
27
+ }
28
+
29
+ def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
30
+ if not params['apply']:
31
+ return img
32
+ else:
33
+ for ch in params['apply_to_channel']:
34
+ mn = img[ch].mean()
35
+ img[ch] -= mn
36
+ img[ch] *= -1
37
+ img[ch] += mn
38
+ return img
39
+
40
+
41
+ if __name__ == '__main__':
42
+ mbt = InvertImageTransform(0.5, 0.5, 0.5)
43
+ from batchviewer import view_batch
44
+
45
+ for _ in range(100):
46
+ data_dict = {'image': torch.ones((2, 20, 192, 64))}
47
+ data_dict['image'][0, :10] = -1
48
+ data_dict['image'][1, :5] = -1
49
+ ret = mbt(**data_dict)
50
+ print(ret['image'][0, 0, 0, 0], ret['image'][1, 0, 0, 0])
51
+ view_batch(mbt(**data_dict)['image'])
@@ -8,9 +8,13 @@ from torch.nn.functional import interpolate
8
8
 
9
9
 
10
10
  class SimulateLowResolutionTransform(ImageOnlyTransform):
11
- def __init__(self, scale: RandomScalar, synchronize_channels: bool, synchronize_axes: bool,
11
+ def __init__(self,
12
+ scale: RandomScalar,
13
+ synchronize_channels: bool,
14
+ synchronize_axes: bool,
12
15
  ignore_axes: Tuple[int, ...],
13
- allowed_channels: Tuple[int, ...] = None, p_per_channel: float = 1):
16
+ allowed_channels: Tuple[int, ...] = None,
17
+ p_per_channel: float = 1):
14
18
  super().__init__()
15
19
  self.scale = scale
16
20
  self.synchronize_channels = synchronize_channels
@@ -7,7 +7,7 @@ import SimpleITK
7
7
  import numpy as np
8
8
  import pandas as pd
9
9
  import torch
10
- from scipy.ndimage import fourier_gaussian
10
+ from scipy.ndimage import fourier_gaussian, gaussian_filter
11
11
  from torch import Tensor
12
12
  from torch.nn.functional import grid_sample
13
13
 
@@ -31,8 +31,18 @@ class SpatialTransform(BasicTransform):
31
31
  scaling: RandomScalar = (0.7, 1.3),
32
32
  p_synchronize_scaling_across_axes: float = 0,
33
33
  bg_style_seg_sampling: bool = True,
34
- mode_seg: str = 'bilinear'
34
+ mode_seg: str = 'bilinear',
35
+ border_mode_seg: str = "zeros",
36
+ center_deformation: bool = True,
37
+ padding_mode_image: str = "zeros"
35
38
  ):
39
+ """
40
+ magnitude must be given in pixels!
41
+ deformation scale is given as a paercentage of the edge length
42
+
43
+ padding_mode_image: see torch grid_sample documentation. This currently applies to image and regression target
44
+ because both call self._apply_to_image. Can be "zeros", "reflection", "border"
45
+ """
36
46
  super().__init__()
37
47
  self.patch_size = patch_size
38
48
  if not isinstance(patch_center_dist_from_border, (tuple, list)):
@@ -41,7 +51,7 @@ class SpatialTransform(BasicTransform):
41
51
  self.random_crop = random_crop
42
52
  self.p_elastic_deform = p_elastic_deform
43
53
  self.elastic_deform_scale = elastic_deform_scale # sigma for blurring offsets, in % of patch size. Larger values mean coarser deformation
44
- self.elastic_deform_magnitude = elastic_deform_magnitude # determines the maximum displacement, measured in % of patch size
54
+ self.elastic_deform_magnitude = elastic_deform_magnitude # determines the maximum displacement, measured in pixels!!
45
55
  self.p_rotation = p_rotation
46
56
  self.rotation = rotation
47
57
  self.p_scaling = p_scaling
@@ -50,9 +60,11 @@ class SpatialTransform(BasicTransform):
50
60
  self.p_synchronize_def_scale_across_axes = p_synchronize_def_scale_across_axes
51
61
  self.bg_style_seg_sampling = bg_style_seg_sampling
52
62
  self.mode_seg = mode_seg
63
+ self.border_mode_seg = border_mode_seg
64
+ self.center_deformation = center_deformation
65
+ self.padding_mode_image = padding_mode_image
53
66
 
54
67
  def get_parameters(self, **data_dict) -> dict:
55
- # note that we revert the axis order here because grid_sample uses dimensions in reverse order!
56
68
  dim = data_dict['image'].ndim - 1
57
69
 
58
70
  do_rotation = np.random.uniform() < self.p_rotation
@@ -60,14 +72,14 @@ class SpatialTransform(BasicTransform):
60
72
  do_deform = np.random.uniform() < self.p_elastic_deform
61
73
 
62
74
  if do_rotation:
63
- angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(dim - 1, -1, -1)]
75
+ angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(0, 3)]
64
76
  else:
65
77
  angles = [0] * dim
66
78
  if do_scale:
67
79
  if np.random.uniform() <= self.p_synchronize_scaling_across_axes:
68
80
  scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=None)] * dim
69
81
  else:
70
- scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(dim - 1, -1, -1)]
82
+ scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(0, 3)]
71
83
  else:
72
84
  scales = [1] * dim
73
85
 
@@ -85,7 +97,6 @@ class SpatialTransform(BasicTransform):
85
97
  # elastic deformation. We need to create the displacement field here
86
98
  # we use the method from augment_spatial_2 in batchgenerators
87
99
  if do_deform:
88
- grid_scale = [i / j for i, j in zip(data_dict['image'].shape[1:], self.patch_size)][::-1]
89
100
  if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:
90
101
  deformation_scales = [
91
102
  sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)
@@ -93,16 +104,16 @@ class SpatialTransform(BasicTransform):
93
104
  else:
94
105
  deformation_scales = [
95
106
  sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
96
- for i in range(dim - 1, -1, -1)
107
+ for i in range(dim)
97
108
  ]
98
109
 
99
110
  # sigmas must be in pixels, as this will be applied to the deformation field
100
- sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)][::-1]
101
- # the magnitude of the deformation field must adhere to the torch's value range for grid_sample, i.e. [-1. 1] and not pixel coordinates. Do not use sigmas here
102
- # we need to correct magnitude by grid_scale to account for the fact that the grid will be wrt to the image size but the magnitude should be wrt the patch size. oof.
111
+ sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]
112
+
103
113
  magnitude = [
104
114
  sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
105
- dim=i, deformation_scale=deformation_scales[i]) / grid_scale[i] for i in range(dim - 1, -1, -1)]
115
+ dim=i, deformation_scale=deformation_scales[i])
116
+ for i in range(dim)]
106
117
  # doing it like this for better memory layout for blurring
107
118
  offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))
108
119
 
@@ -117,18 +128,24 @@ class SpatialTransform(BasicTransform):
117
128
  tmp = fourier_gaussian(tmp, sigmas[d])
118
129
  offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
119
130
 
131
+ # tmp = offsets[d].numpy().astype(np.float64)
132
+ # gaussian_filter(tmp, sigmas[d], 0, output=tmp)
133
+ # offsets[d] = torch.from_numpy(tmp).to(offsets.dtype)
134
+ # print(offsets.dtype)
135
+
120
136
  mx = torch.max(torch.abs(offsets[d]))
121
137
  offsets[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf))
122
- offsets = torch.permute(offsets, (1, 2, 3, 0))
138
+ spatial_dims = tuple(list(range(1, dim + 1)))
139
+ offsets = torch.permute(offsets, (*spatial_dims, 0))
123
140
  else:
124
141
  offsets = None
125
- # grid center must be in [-1, 1] as required by grid_sample
142
+
126
143
  shape = data_dict['image'].shape[1:]
127
144
  if not self.random_crop:
128
- center_location_in_pixels = [i / 2 for i in shape][::-1]
145
+ center_location_in_pixels = [i / 2 for i in shape]
129
146
  else:
130
147
  center_location_in_pixels = []
131
- for d in range(dim - 1, -1, -1):
148
+ for d in range(0, 3):
132
149
  mn = self.patch_center_dist_from_border[d]
133
150
  mx = shape[d] - self.patch_center_dist_from_border[d]
134
151
  if mx < mn:
@@ -146,15 +163,25 @@ class SpatialTransform(BasicTransform):
146
163
  # No spatial transformation is being done. Round grid_center and crop without having to interpolate.
147
164
  # This saves compute.
148
165
  # cropping requires the center to be given as integer coordinates
149
- img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']][::-1], self.patch_size, pad_mode='constant',
150
- pad_kwargs={'value': 0})
166
+
167
+ # torch is inconsistent. AAAAaaah
168
+ if self.padding_mode_image == 'reflection':
169
+ pad_mode = 'reflect'
170
+ pad_kwargs = {}
171
+ elif self.padding_mode_image == 'zeros':
172
+ pad_mode = 'constant'
173
+ {'value': 0}
174
+ elif self.padding_mode_image == 'border':
175
+ pad_mode = 'replicate'
176
+ pad_kwargs = {}
177
+ else:
178
+ raise RuntimeError('Unknown pad mode')
179
+
180
+ img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']], self.patch_size, pad_mode=pad_mode,
181
+ pad_kwargs=pad_kwargs)
151
182
  return img
152
183
  else:
153
- grid = _create_identity_grid(self.patch_size)
154
-
155
- # the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
156
- grid_scale = torch.Tensor([i / j for i, j in zip(img.shape[1:], self.patch_size)][::-1])
157
- grid /= grid_scale
184
+ grid = _create_centered_identity_grid2(self.patch_size)
158
185
 
159
186
  # we deform first, then rotate
160
187
  if params['elastic_offsets'] is not None:
@@ -164,14 +191,16 @@ class SpatialTransform(BasicTransform):
164
191
 
165
192
  # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
166
193
  # only do this if we elastic deform
167
- if params['elastic_offsets'] is not None:
194
+ if self.center_deformation and params['elastic_offsets'] is not None:
168
195
  mn = grid.mean(dim=list(range(img.ndim - 1)))
169
196
  else:
170
197
  mn = 0
171
- new_center = torch.Tensor(
172
- [(j / (i / 2) - 1) for i, j in zip(img.shape[1:][::-1], params['center_location_in_pixels'])])
173
- grid += - mn + new_center
174
- return grid_sample(img[None], grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)[0]
198
+
199
+ new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])])
200
+ grid += (new_center - mn)
201
+ # print(f'grid sample with pad mode {self.padding_mode_image}')
202
+ return grid_sample(img[None], _convert_my_grid_to_grid_sample_grid(grid, img.shape[1:])[None],
203
+ mode='bilinear', padding_mode=self.padding_mode_image, align_corners=False)[0]
175
204
 
176
205
  def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
177
206
  segmentation = segmentation.contiguous()
@@ -180,17 +209,13 @@ class SpatialTransform(BasicTransform):
180
209
  # This saves compute.
181
210
  # cropping requires the center to be given as integer coordinates
182
211
  segmentation = crop_tensor(segmentation,
183
- [math.floor(i) for i in params['center_location_in_pixels']][::-1],
212
+ [math.floor(i) for i in params['center_location_in_pixels']],
184
213
  self.patch_size,
185
214
  pad_mode='constant',
186
215
  pad_kwargs={'value': 0})
187
216
  return segmentation
188
217
  else:
189
- grid = _create_identity_grid(self.patch_size)
190
-
191
- # the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
192
- grid_scale = torch.Tensor([i / j for i, j in zip(segmentation.shape[1:], self.patch_size)][::-1])
193
- grid /= grid_scale
218
+ grid = _create_centered_identity_grid2(self.patch_size)
194
219
 
195
220
  # we deform first, then rotate
196
221
  if params['elastic_offsets'] is not None:
@@ -199,20 +224,22 @@ class SpatialTransform(BasicTransform):
199
224
  grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
200
225
 
201
226
  # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
202
- if params['elastic_offsets'] is not None:
227
+ if self.center_deformation and params['elastic_offsets'] is not None:
203
228
  mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
204
229
  else:
205
230
  mn = 0
206
- new_center = torch.Tensor(
207
- [(j / (i / 2) - 1) for i, j in zip(segmentation.shape[1:][::-1], params['center_location_in_pixels'])])
208
- grid += - mn + new_center
231
+
232
+ new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], segmentation.shape[1:])])
233
+
234
+ grid += (new_center - mn)
235
+ grid = _convert_my_grid_to_grid_sample_grid(grid, segmentation.shape[1:])
209
236
 
210
237
  if self.mode_seg == 'nearest':
211
238
  result_seg = grid_sample(
212
239
  segmentation[None].float(),
213
240
  grid[None],
214
241
  mode=self.mode_seg,
215
- padding_mode="zeros",
242
+ padding_mode=self.border_mode_seg,
216
243
  align_corners=False
217
244
  )[0].to(segmentation.dtype)
218
245
  else:
@@ -226,7 +253,7 @@ class SpatialTransform(BasicTransform):
226
253
  ((segmentation[c] == labels[1]).float())[None, None],
227
254
  grid[None],
228
255
  mode=self.mode_seg,
229
- padding_mode="zeros",
256
+ padding_mode=self.border_mode_seg,
230
257
  align_corners=False
231
258
  )[0][0] >= 0.5
232
259
  result_seg[c][out] = labels[1]
@@ -238,7 +265,7 @@ class SpatialTransform(BasicTransform):
238
265
  ((segmentation[c] == u).float())[None, None],
239
266
  grid[None],
240
267
  mode=self.mode_seg,
241
- padding_mode="zeros",
268
+ padding_mode=self.border_mode_seg,
242
269
  align_corners=False
243
270
  )[0][0] >= 0.5] = u
244
271
  else:
@@ -250,7 +277,7 @@ class SpatialTransform(BasicTransform):
250
277
  done_mask = torch.zeros(*self.patch_size, dtype=torch.bool)
251
278
  for i, u in enumerate(labels):
252
279
  tmp[i] = grid_sample(((segmentation[c] == u).float() * scale_factor)[None, None], grid[None],
253
- mode=self.mode_seg, padding_mode="zeros", align_corners=False)[0][0]
280
+ mode=self.mode_seg, padding_mode=self.border_mode_seg, align_corners=False)[0][0]
254
281
  mask = tmp[i] > (0.7 * scale_factor)
255
282
  result_seg[c][mask] = u
256
283
  done_mask = done_mask | mask
@@ -305,15 +332,38 @@ def create_affine_matrix_2d(rotation_angle, scaling_factors):
305
332
  return RS
306
333
 
307
334
 
308
- def _create_identity_grid(size: List[int]) -> Tensor:
309
- space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size[::-1]]
335
+ # def _create_identity_grid(size: List[int]) -> Tensor:
336
+ # space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size[::-1]]
337
+ # grid = torch.meshgrid(space, indexing="ij")
338
+ # grid = torch.stack(grid, -1)
339
+ # spatial_dims = list(range(len(size)))
340
+ # grid = grid.permute((*spatial_dims[::-1], len(size)))
341
+ # return grid
342
+
343
+
344
+ def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
345
+ space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
310
346
  grid = torch.meshgrid(space, indexing="ij")
311
347
  grid = torch.stack(grid, -1)
312
- spatial_dims = list(range(len(size)))
313
- grid = grid.permute((*spatial_dims[::-1], len(size)))
314
348
  return grid
315
349
 
316
350
 
351
+ def _convert_my_grid_to_grid_sample_grid(my_grid: torch.Tensor, original_shape: Union[Tuple[int, ...], List[int]]):
352
+ # rescale
353
+ for d in range(len(original_shape)):
354
+ s = original_shape[d]
355
+ my_grid[..., d] /= (s / 2)
356
+ my_grid = torch.flip(my_grid, (len(my_grid.shape) - 1, ))
357
+ # my_grid = my_grid.flip((len(my_grid.shape) - 1,))
358
+ return my_grid
359
+
360
+
361
+ # size = (4, 5, 6)
362
+ # grid_old = _create_identity_grid(size)
363
+ # grid_new = _create_centered_identity_grid2(size)
364
+ # grid_new_converted = _convert_my_grid_to_grid_sample_grid(grid_new, size)
365
+ # torch.all(torch.isclose(grid_new_converted, grid_old))
366
+
317
367
  # An alternative way of generating the displacement fieldQ
318
368
  # def displacement_field(data: torch.Tensor):
319
369
  # downscaling_global = np.random.uniform() ** 2 * 4 + 2
@@ -395,30 +445,86 @@ if __name__ == '__main__':
395
445
  # with this part we can qualitatively test that the correct axes are ebing augmented. Just set one of the probs to 1 and off you go
396
446
  #################
397
447
 
398
- # def constant_scaling(image, dim, patch_size):
399
- # return 0.1
400
- #
401
- # def constant_magnitude(image, dim, patch_size, deformation_scale):
402
- # return 0.25 if dim == 2 else 0
403
- #
404
- # def rot(image, dim):
405
- # return 45/360 * 2 * np.pi if dim == 1 else 0
406
- #
448
+ def eldef_scale(image, dim, patch_size):
449
+ return 0.1
450
+
451
+ def eldef_magnitude(image, dim, patch_size, deformation_scale):
452
+ return 10 if dim == 2 else 0
453
+
454
+ def rot(image, dim):
455
+ return 45/360 * 2 * np.pi if dim == 0 else 0
456
+
457
+ def scaling(image, dim):
458
+ return 0.5 if dim == 0 else 1
459
+
460
+ # lines
461
+ patch = torch.zeros((1, 64, 60, 68))
462
+ patch[:, :, 10, 30] = 1
463
+ patch[:, 50, :, 30] = 1
464
+ patch[:, 40, 20, :] = 1
465
+
466
+ # patch_block
467
+ patch_block = torch.zeros((1, 64, 60, 68))
468
+ patch_block[:, 22:42, 20:40, 24:44] = 1
469
+
470
+ patch_line = torch.zeros((1, 64, 60, 128))
471
+ patch_line[:, 22:24, 30:32, 10:-10] = 1
472
+ use = patch_line
473
+
474
+ sp = SpatialTransform(
475
+ patch_size=patch.shape[1:],
476
+ patch_center_dist_from_border=0,
477
+ random_crop=False,
478
+ p_elastic_deform=0,
479
+ p_rotation=1,
480
+ p_scaling=0,
481
+ elastic_deform_scale=eldef_scale,
482
+ elastic_deform_magnitude=eldef_magnitude,
483
+ p_synchronize_def_scale_across_axes=0,
484
+ rotation=rot,
485
+ scaling=scaling,
486
+ p_synchronize_scaling_across_axes=0,
487
+ bg_style_seg_sampling=False,
488
+ mode_seg='bilinear'
489
+ )
490
+
491
+
492
+ SimpleITK.WriteImage(SimpleITK.GetImageFromArray(use[0].numpy()), 'orig.nii.gz')
493
+
494
+ params = sp.get_parameters(image=use)
495
+ transformed = sp._apply_to_image(use, **params)
496
+
497
+ SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
498
+
499
+ # p = torch.zeros((1, 1, 8, 16, 32))
500
+ # p[:, :, 2:6, 10:16, 10:24] = 1
501
+ # grid = _create_identity_grid(p.shape[2:])
502
+ # grid[:, :, :, 0] *= 0.5
503
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
504
+ # torch.all(out == p)
505
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
506
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
507
+
508
+ #################
509
+ # with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
510
+ #################
511
+
407
512
  # sp = SpatialTransform(
408
- # patch_size=(64, 60, 68),
513
+ # patch_size=(48, 52, 54),
409
514
  # patch_center_dist_from_border=0,
410
- # random_crop=False,
515
+ # random_crop=True,
411
516
  # p_elastic_deform=0,
412
- # elastic_deform_scale=0,
413
- # elastic_deform_magnitude=0,
414
- # p_synchronize_def_scale_across_axes=0,
415
517
  # p_rotation=1,
416
- # rotation=rot,
417
518
  # p_scaling=0,
418
- # scaling=constant_scaling,
419
- # p_synchronize_scaling_across_axes=0,
420
- # bg_style_seg_sampling=False,
421
- # mode_seg='bilinear'
519
+ # rotation=0
520
+ # )
521
+ # sp2 = SpatialTransform(
522
+ # patch_size=(48, 52, 54),
523
+ # patch_center_dist_from_border=0,
524
+ # random_crop=True,
525
+ # p_elastic_deform=0,
526
+ # p_rotation=0,
527
+ # p_scaling=0,
422
528
  # )
423
529
  #
424
530
  # patch = torch.zeros((1, 64, 60, 68))
@@ -427,62 +533,20 @@ if __name__ == '__main__':
427
533
  # patch[:, 40, 20, :] = 1
428
534
  # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
429
535
  #
536
+ # center_coords = [50, 10, 16]
430
537
  # params = sp.get_parameters(image=patch)
538
+ # params['center_location_in_pixels'] = center_coords
539
+ # params2 = sp2.get_parameters(image=patch)
540
+ # params2['center_location_in_pixels'] = center_coords
431
541
  # transformed = sp._apply_to_image(patch, **params)
542
+ # transformed2 = sp._apply_to_image(patch, **params2)
432
543
  #
433
544
  # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
434
-
435
- # p = torch.zeros((1, 1, 8, 16, 32))
436
- # p[:, :, 2:6, 10:16, 10:24] = 1
437
- # grid = _create_identity_grid(p.shape[2:])
438
- # grid[:, :, :, 0] *= 0.5
439
- # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
440
- # torch.all(out == p)
441
- # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
442
- # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
443
-
444
- #################
445
- # with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
446
- #################
447
-
448
- sp = SpatialTransform(
449
- patch_size=(48, 52, 54),
450
- patch_center_dist_from_border=0,
451
- random_crop=True,
452
- p_elastic_deform=0,
453
- p_rotation=1,
454
- p_scaling=0,
455
- rotation=0
456
- )
457
- sp2 = SpatialTransform(
458
- patch_size=(48, 52, 54),
459
- patch_center_dist_from_border=0,
460
- random_crop=True,
461
- p_elastic_deform=0,
462
- p_rotation=0,
463
- p_scaling=0,
464
- )
465
-
466
- patch = torch.zeros((1, 64, 60, 68))
467
- patch[:, :, 10, 30] = 1
468
- patch[:, 50, :, 30] = 1
469
- patch[:, 40, 20, :] = 1
470
- SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
471
-
472
- center_coords = [30, 28, 44]
473
- params = sp.get_parameters(image=patch)
474
- params['center_location_in_pixels'] = center_coords
475
- params2 = sp2.get_parameters(image=patch)
476
- params2['center_location_in_pixels'] = center_coords
477
- transformed = sp._apply_to_image(patch, **params)
478
- transformed2 = sp._apply_to_image(patch, **params)
479
-
480
- SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
481
- SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
545
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
482
546
 
483
547
 
484
548
 
485
- ####################
549
+ ####################
486
550
  # This is exploraroty code to check how to retrieve coordinates. I used it to verify that grid_sample does in fact
487
551
  # use coordinates in reversed dimension order (zyx and not xyz)
488
552
  ####################
@@ -0,0 +1,67 @@
1
+ from typing import Set
2
+ import numpy as np
3
+ import torch
4
+
5
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
6
+
7
+
8
+ class TransposeAxesTransform(BasicTransform):
9
+ """
10
+ A transformation class to permute specified spatial axes of an image and related data.
11
+
12
+ Attributes:
13
+ allowed_axes (Set[int]): Set of spatial axes allowed for permutation (e.g., {1, 2} for y and z axes in an
14
+ image of shape (c, x, y, z)).
15
+ """
16
+
17
+ def __init__(self, allowed_axes: Set[int]):
18
+ """
19
+ Initialize the transform with allowed spatial axes for permutation.
20
+
21
+ Args:
22
+ allowed_axes (Set[int]): Set of spatial axis indices for permutation.
23
+ """
24
+ super().__init__()
25
+ self.allowed_axes = allowed_axes
26
+
27
+ def get_parameters(self, **data_dict) -> dict:
28
+ """
29
+ Generate a random axis permutation order.
30
+
31
+ Args:
32
+ data_dict (dict): Dictionary containing `image` tensor data.
33
+
34
+ Returns:
35
+ dict: Permutation order of axes as 'axis_order'.
36
+ """
37
+ shape_of_allowed = [data_dict['image'].shape[1 + i] for i in self.allowed_axes]
38
+ if len(shape_of_allowed) < 2:
39
+ return {'axis_order': list(range(len(data_dict['image'].shape)))}
40
+ if not all(i == shape_of_allowed[0] for i in shape_of_allowed[1:]):
41
+ raise ValueError(f"Axis shapes are not identical: {shape_of_allowed}. Cannot permute.\n"
42
+ f"Image shape: {data_dict['image'].shape}. Allowed axes: {self.allowed_axes}")
43
+
44
+ axes = [i + 1 for i in self.allowed_axes]
45
+ np.random.shuffle(axes)
46
+ axis_order = np.arange(len(data_dict['image'].shape))
47
+ axis_order[np.isin(axis_order, axes)] = axes
48
+ return {'axis_order': [int(i) for i in axis_order]}
49
+
50
+ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
51
+ return segmentation.permute(params['axis_order']).contiguous()
52
+
53
+ def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
54
+ return img.permute(params['axis_order']).contiguous()
55
+
56
+ def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
57
+ return regression_target.permute(params['axis_order']).contiguous()
58
+
59
+ def _apply_to_bbox(self, bbox, **params):
60
+ raise NotImplementedError
61
+
62
+ def _apply_to_keypoints(self, keypoints, **params):
63
+ raise NotImplementedError
64
+
65
+ if __name__ == '__main__':
66
+ t = TransposeAxesTransform((1, 2))
67
+ ret = t(**{'image': torch.rand((2, 31, 32, 32)), 'segmentation': torch.ones((1, 31, 32, 32))})
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: batchgeneratorsv2
3
- Version: 0.2
3
+ Version: 0.2.2
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -24,6 +24,7 @@ batchgeneratorsv2/transforms/intensity/brightness.py
24
24
  batchgeneratorsv2/transforms/intensity/contrast.py
25
25
  batchgeneratorsv2/transforms/intensity/gamma.py
26
26
  batchgeneratorsv2/transforms/intensity/gaussian_noise.py
27
+ batchgeneratorsv2/transforms/intensity/inversion.py
27
28
  batchgeneratorsv2/transforms/nnunet/__init__.py
28
29
  batchgeneratorsv2/transforms/nnunet/random_binary_operator.py
29
30
  batchgeneratorsv2/transforms/nnunet/remove_connected_components.py
@@ -34,6 +35,7 @@ batchgeneratorsv2/transforms/spatial/__init__.py
34
35
  batchgeneratorsv2/transforms/spatial/low_resolution.py
35
36
  batchgeneratorsv2/transforms/spatial/mirroring.py
36
37
  batchgeneratorsv2/transforms/spatial/spatial.py
38
+ batchgeneratorsv2/transforms/spatial/transpose.py
37
39
  batchgeneratorsv2/transforms/utils/__init__.py
38
40
  batchgeneratorsv2/transforms/utils/compose.py
39
41
  batchgeneratorsv2/transforms/utils/cropping.py
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "batchgeneratorsv2"
3
- version = "0.2"
3
+ version = "0.2.2"
4
4
  requires-python = ">=3.9"
5
5
  description = "Batchgenerators but better"
6
6
  readme = "readme.md"