batchgeneratorsv2 0.1.1__tar.gz → 0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/PKG-INFO +1 -1
  2. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +1 -1
  3. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +1 -1
  4. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/base/basic_transform.py +5 -0
  5. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/contrast.py +3 -0
  6. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/spatial.py +220 -76
  7. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/random.py +5 -1
  8. batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  9. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/PKG-INFO +1 -1
  10. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/pyproject.toml +1 -1
  11. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/seg_to_regions.py +0 -23
  12. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/LICENSE +0 -0
  13. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/__init__.py +0 -0
  14. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/__init__.py +0 -0
  15. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  16. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
  17. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/dataloading/__init__.py +0 -0
  18. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/helpers/__init__.py +0 -0
  19. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
  20. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/__init__.py +0 -0
  21. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/base/__init__.py +0 -0
  22. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  23. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/brightness.py +0 -0
  24. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/gamma.py +0 -0
  25. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -0
  26. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +0 -0
  28. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
  29. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
  30. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +0 -0
  32. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  33. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/low_resolution.py +0 -0
  34. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
  35. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  36. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/compose.py +0 -0
  37. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
  38. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
  39. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +0 -0
  40. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/pseudo2d.py +0 -0
  41. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/remove_label.py +0 -0
  42. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/SOURCES.txt +0 -0
  43. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
  44. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/requires.txt +0 -0
  45. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
  46. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/readme.md +0 -0
  47. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/setup.cfg +0 -0
  48. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: batchgeneratorsv2
3
- Version: 0.1.1
3
+ Version: 0.2
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -87,4 +87,4 @@ if __name__ == '__main__':
87
87
  compute_times[i].append(time() - st)
88
88
 
89
89
  for t, ct in zip(tr_transforms, compute_times):
90
- print(t.__class__.__name__, np.percentile(ct, 20))
90
+ print(t.__class__.__name__, np.mean(ct))
@@ -135,4 +135,4 @@ if __name__ == '__main__':
135
135
  compute_times[i].append(time() - st)
136
136
 
137
137
  for t, ct in zip(transforms, compute_times):
138
- print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.percentile(ct, 20))
138
+ print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.mean(ct))
@@ -53,6 +53,11 @@ class BasicTransform(abc.ABC):
53
53
  def get_parameters(self, **data_dict) -> dict:
54
54
  return {}
55
55
 
56
+ def __repr__(self):
57
+ ret_str = str(type(self).__name__) + "( " + ", ".join(
58
+ [key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )"
59
+ return ret_str
60
+
56
61
 
57
62
  class ImageOnlyTransform(BasicTransform):
58
63
  def apply(self, data_dict: dict, **params) -> dict:
@@ -22,6 +22,9 @@ class BGContrast():
22
22
  def __call__(self, *args, **kwargs):
23
23
  return self.sample_contrast(*args, **kwargs)
24
24
 
25
+ def __repr__(self):
26
+ return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
27
+
25
28
  class ContrastTransform(ImageOnlyTransform):
26
29
  def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
27
30
  super().__init__()
@@ -1,6 +1,9 @@
1
1
  from copy import deepcopy
2
2
  from typing import Tuple, List, Union
3
3
 
4
+ import math
5
+
6
+ import SimpleITK
4
7
  import numpy as np
5
8
  import pandas as pd
6
9
  import torch
@@ -14,14 +17,19 @@ from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
14
17
 
15
18
 
16
19
  class SpatialTransform(BasicTransform):
17
- def __init__(self, patch_size: Tuple[int, ...],
20
+ def __init__(self,
21
+ patch_size: Tuple[int, ...],
18
22
  patch_center_dist_from_border: Union[int, List[int], Tuple[int, ...]],
19
23
  random_crop: bool,
20
- p_elastic_deform: float = 0, elastic_deform_scale: RandomScalar = (0, 0.2),
24
+ p_elastic_deform: float = 0,
25
+ elastic_deform_scale: RandomScalar = (0, 0.2),
21
26
  elastic_deform_magnitude: RandomScalar = (0, 0.2),
22
- p_synchronize_def_scale_across_axes: float = False,
23
- p_rotation: float = 0, rotation: RandomScalar = (0, 2 * np.pi),
24
- p_scaling: float = 0, scaling: RandomScalar = (0.7, 1.3), p_synchronize_scaling_across_axes: float = False,
27
+ p_synchronize_def_scale_across_axes: float = 0,
28
+ p_rotation: float = 0,
29
+ rotation: RandomScalar = (0, 2 * np.pi),
30
+ p_scaling: float = 0,
31
+ scaling: RandomScalar = (0.7, 1.3),
32
+ p_synchronize_scaling_across_axes: float = 0,
25
33
  bg_style_seg_sampling: bool = True,
26
34
  mode_seg: str = 'bilinear'
27
35
  ):
@@ -44,6 +52,7 @@ class SpatialTransform(BasicTransform):
44
52
  self.mode_seg = mode_seg
45
53
 
46
54
  def get_parameters(self, **data_dict) -> dict:
55
+ # note that we revert the axis order here because grid_sample uses dimensions in reverse order!
47
56
  dim = data_dict['image'].ndim - 1
48
57
 
49
58
  do_rotation = np.random.uniform() < self.p_rotation
@@ -51,21 +60,23 @@ class SpatialTransform(BasicTransform):
51
60
  do_deform = np.random.uniform() < self.p_elastic_deform
52
61
 
53
62
  if do_rotation:
54
- angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(dim)]
63
+ angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(dim - 1, -1, -1)]
55
64
  else:
56
65
  angles = [0] * dim
57
66
  if do_scale:
58
- scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in
59
- range(dim)] if np.random.uniform() < self.p_synchronize_scaling_across_axes else [sample_scalar(
60
- self.scaling, image=data_dict['image'], dim=None)] * dim
67
+ if np.random.uniform() <= self.p_synchronize_scaling_across_axes:
68
+ scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=None)] * dim
69
+ else:
70
+ scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(dim - 1, -1, -1)]
61
71
  else:
62
72
  scales = [1] * dim
73
+
63
74
  # affine matrix
64
75
  if do_scale or do_rotation:
65
76
  if dim == 3:
66
77
  affine = create_affine_matrix_3d(angles, scales)
67
78
  elif dim == 2:
68
- affine = create_affine_matrix_2d(angles[0], scales)
79
+ affine = create_affine_matrix_2d(angles[-1], scales)
69
80
  else:
70
81
  raise RuntimeError(f'Unsupported dimension: {dim}')
71
82
  else:
@@ -74,19 +85,24 @@ class SpatialTransform(BasicTransform):
74
85
  # elastic deformation. We need to create the displacement field here
75
86
  # we use the method from augment_spatial_2 in batchgenerators
76
87
  if do_deform:
77
- grid_scale = [i / j for i, j in zip(data_dict['image'].shape[1:], self.patch_size)]
78
- deformation_scales = [
79
- sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
80
- for i in
81
- range(dim)] if np.random.uniform() < self.p_synchronize_scaling_across_axes else [sample_scalar(
82
- self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)] * dim
88
+ grid_scale = [i / j for i, j in zip(data_dict['image'].shape[1:], self.patch_size)][::-1]
89
+ if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:
90
+ deformation_scales = [
91
+ sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)
92
+ ] * dim
93
+ else:
94
+ deformation_scales = [
95
+ sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
96
+ for i in range(dim - 1, -1, -1)
97
+ ]
98
+
83
99
  # sigmas must be in pixels, as this will be applied to the deformation field
84
- sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]
100
+ sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)][::-1]
85
101
  # the magnitude of the deformation field must adhere to torch's value range for grid_sample, i.e. [-1, 1] and not pixel coordinates. Do not use sigmas here
86
102
  # we need to correct magnitude by grid_scale to account for the fact that the grid will be wrt the image size but the magnitude should be wrt the patch size. oof.
87
103
  magnitude = [
88
104
  sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
89
- dim=i, deformation_scale=deformation_scales[i]) / grid_scale[i] for i in range(dim)]
105
+ dim=i, deformation_scale=deformation_scales[i]) / grid_scale[i] for i in range(dim - 1, -1, -1)]
90
106
  # doing it like this for better memory layout for blurring
91
107
  offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))
92
108
 
@@ -98,7 +114,7 @@ class SpatialTransform(BasicTransform):
98
114
 
99
115
  # fft numpy, this is faster o.O
100
116
  tmp = np.fft.fftn(offsets[d].numpy())
101
- tmp = fourier_gaussian(tmp, sigmas)
117
+ tmp = fourier_gaussian(tmp, sigmas[d])
102
118
  offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
103
119
 
104
120
  mx = torch.max(torch.abs(offsets[d]))
@@ -109,10 +125,10 @@ class SpatialTransform(BasicTransform):
109
125
  # grid center must be in [-1, 1] as required by grid_sample
110
126
  shape = data_dict['image'].shape[1:]
111
127
  if not self.random_crop:
112
- center_location_in_pixels = [i / 2 for i in shape]
128
+ center_location_in_pixels = [i / 2 for i in shape][::-1]
113
129
  else:
114
130
  center_location_in_pixels = []
115
- for d in range(dim):
131
+ for d in range(dim - 1, -1, -1):
116
132
  mn = self.patch_center_dist_from_border[d]
117
133
  mx = shape[d] - self.patch_center_dist_from_border[d]
118
134
  if mx < mn:
@@ -130,14 +146,14 @@ class SpatialTransform(BasicTransform):
130
146
  # No spatial transformation is being done. Round grid_center and crop without having to interpolate.
131
147
  # This saves compute.
132
148
  # cropping requires the center to be given as integer coordinates
133
- img = crop_tensor(img, [round(i) for i in params['center_location_in_pixels']], self.patch_size, pad_mode='constant',
149
+ img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']][::-1], self.patch_size, pad_mode='constant',
134
150
  pad_kwargs={'value': 0})
135
151
  return img
136
152
  else:
137
153
  grid = _create_identity_grid(self.patch_size)
138
154
 
139
155
  # the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
140
- grid_scale = torch.Tensor([i / j for i, j in zip(img.shape[1:], self.patch_size)])
156
+ grid_scale = torch.Tensor([i / j for i, j in zip(img.shape[1:], self.patch_size)][::-1])
141
157
  grid /= grid_scale
142
158
 
143
159
  # we deform first, then rotate
@@ -147,9 +163,13 @@ class SpatialTransform(BasicTransform):
147
163
  grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
148
164
 
149
165
  # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
150
- mn = grid.mean(dim=list(range(img.ndim - 1)))
166
+ # only do this if we elastic deform
167
+ if params['elastic_offsets'] is not None:
168
+ mn = grid.mean(dim=list(range(img.ndim - 1)))
169
+ else:
170
+ mn = 0
151
171
  new_center = torch.Tensor(
152
- [(j / (i / 2) - 1) for i, j in zip(img.shape[1:], params['center_location_in_pixels'])])
172
+ [(j / (i / 2) - 1) for i, j in zip(img.shape[1:][::-1], params['center_location_in_pixels'])])
153
173
  grid += - mn + new_center
154
174
  return grid_sample(img[None], grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)[0]
155
175
 
@@ -159,14 +179,17 @@ class SpatialTransform(BasicTransform):
159
179
  # No spatial transformation is being done. Round grid_center and crop without having to interpolate.
160
180
  # This saves compute.
161
181
  # cropping requires the center to be given as integer coordinates
162
- segmentation = crop_tensor(segmentation, [round(i) for i in params['center_location_in_pixels']], self.patch_size,
163
- pad_mode='constant', pad_kwargs={'value': 0})
182
+ segmentation = crop_tensor(segmentation,
183
+ [math.floor(i) for i in params['center_location_in_pixels']][::-1],
184
+ self.patch_size,
185
+ pad_mode='constant',
186
+ pad_kwargs={'value': 0})
164
187
  return segmentation
165
188
  else:
166
189
  grid = _create_identity_grid(self.patch_size)
167
190
 
168
191
  # the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
169
- grid_scale = torch.Tensor([i / j for i, j in zip(segmentation.shape[1:], self.patch_size)])
192
+ grid_scale = torch.Tensor([i / j for i, j in zip(segmentation.shape[1:], self.patch_size)][::-1])
170
193
  grid /= grid_scale
171
194
 
172
195
  # we deform first, then rotate
@@ -176,9 +199,12 @@ class SpatialTransform(BasicTransform):
176
199
  grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
177
200
 
178
201
  # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
179
- mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
202
+ if params['elastic_offsets'] is not None:
203
+ mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
204
+ else:
205
+ mn = 0
180
206
  new_center = torch.Tensor(
181
- [(j / (i / 2) - 1) for i, j in zip(segmentation.shape[1:], params['center_location_in_pixels'])])
207
+ [(j / (i / 2) - 1) for i, j in zip(segmentation.shape[1:][::-1], params['center_location_in_pixels'])])
182
208
  grid += - mn + new_center
183
209
 
184
210
  if self.mode_seg == 'nearest':
@@ -308,58 +334,176 @@ def _create_identity_grid(size: List[int]) -> Tensor:
308
334
 
309
335
 
310
336
  if __name__ == '__main__':
311
- torch.set_num_threads(1)
312
-
313
- shape = (128, 128, 128)
314
- patch_size = (128, 128, 128)
315
- labels = 2
316
-
317
-
318
- # seg = torch.rand([i // 32 for i in shape]) * labels
319
- # seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
320
- # decimals=0).to(torch.int16)
321
- # img = torch.ones((1, *shape))
322
- # img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
337
+ # torch.set_num_threads(1)
338
+ #
339
+ # shape = (128, 128, 128)
340
+ # patch_size = (128, 128, 128)
341
+ # labels = 2
342
+ #
343
+ #
344
+ # # seg = torch.rand([i // 32 for i in shape]) * labels
345
+ # # seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
346
+ # # decimals=0).to(torch.int16)
347
+ # # img = torch.ones((1, *shape))
348
+ # # img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
349
+ #
350
+ #
351
+ # import SimpleITK as sitk
352
+ # # img = camera()
353
+ # # seg = None
354
+ # img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/imagesTr/BraTS2021_00000_0000.nii.gz'))
355
+ # seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/labelsTr/BraTS2021_00000.nii.gz'))
356
+ #
357
+ # patch_size = (192, 192, 192)
358
+ # sp = SpatialTransform(
359
+ # patch_size=(192, 192, 192),
360
+ # patch_center_dist_from_border=[i / 2 for i in patch_size],
361
+ # random_crop=True,
362
+ # p_elastic_deform=0,
363
+ # elastic_deform_magnitude=(0.1, 0.1),
364
+ # elastic_deform_scale=(0.1, 0.1),
365
+ # p_synchronize_def_scale_across_axes=0.5,
366
+ # p_rotation=1,
367
+ # rotation=(-30 / 360 * np.pi, 30 / 360 * np.pi),
368
+ # p_scaling=1,
369
+ # scaling=(0.75, 1),
370
+ # p_synchronize_scaling_across_axes=0.5,
371
+ # bg_style_seg_sampling=True,
372
+ # mode_seg='bilinear'
373
+ # )
374
+ #
375
+ # data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
376
+ # if seg is not None:
377
+ # data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
378
+ # # out = sp(**data_dict)
379
+ # #
380
+ # # view_batch(out['image'], out['segmentation'])
381
+ #
382
+ # from time import time
383
+ # times = []
384
+ # for _ in range(10):
385
+ # data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
386
+ # if seg is not None:
387
+ # data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
388
+ # st = time()
389
+ # out = sp(**data_dict)
390
+ # times.append(time() - st)
391
+ # print(np.median(times))
392
+
393
+
394
+ #################
395
+ # with this part we can qualitatively test that the correct axes are being augmented. Just set one of the probs to 1 and off you go
396
+ #################
397
+
398
+ # def constant_scaling(image, dim, patch_size):
399
+ # return 0.1
400
+ #
401
+ # def constant_magnitude(image, dim, patch_size, deformation_scale):
402
+ # return 0.25 if dim == 2 else 0
403
+ #
404
+ # def rot(image, dim):
405
+ # return 45/360 * 2 * np.pi if dim == 1 else 0
406
+ #
407
+ # sp = SpatialTransform(
408
+ # patch_size=(64, 60, 68),
409
+ # patch_center_dist_from_border=0,
410
+ # random_crop=False,
411
+ # p_elastic_deform=0,
412
+ # elastic_deform_scale=0,
413
+ # elastic_deform_magnitude=0,
414
+ # p_synchronize_def_scale_across_axes=0,
415
+ # p_rotation=1,
416
+ # rotation=rot,
417
+ # p_scaling=0,
418
+ # scaling=constant_scaling,
419
+ # p_synchronize_scaling_across_axes=0,
420
+ # bg_style_seg_sampling=False,
421
+ # mode_seg='bilinear'
422
+ # )
423
+ #
424
+ # patch = torch.zeros((1, 64, 60, 68))
425
+ # patch[:, :, 10, 30] = 1
426
+ # patch[:, 50, :, 30] = 1
427
+ # patch[:, 40, 20, :] = 1
428
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
429
+ #
430
+ # params = sp.get_parameters(image=patch)
431
+ # transformed = sp._apply_to_image(patch, **params)
432
+ #
433
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
323
434
 
435
+ # p = torch.zeros((1, 1, 8, 16, 32))
436
+ # p[:, :, 2:6, 10:16, 10:24] = 1
437
+ # grid = _create_identity_grid(p.shape[2:])
438
+ # grid[:, :, :, 0] *= 0.5
439
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
440
+ # torch.all(out == p)
441
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
442
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
324
443
 
325
- import SimpleITK as sitk
326
- # img = camera()
327
- # seg = None
328
- img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/imagesTr/BraTS2021_00000_0000.nii.gz'))
329
- seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/labelsTr/BraTS2021_00000.nii.gz'))
444
+ #################
445
+ # with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
446
+ #################
330
447
 
331
- patch_size = (192, 192, 192)
332
448
  sp = SpatialTransform(
333
- patch_size=(192, 192, 192),
334
- patch_center_dist_from_border=[i / 2 for i in patch_size],
449
+ patch_size=(48, 52, 54),
450
+ patch_center_dist_from_border=0,
335
451
  random_crop=True,
336
452
  p_elastic_deform=0,
337
- elastic_deform_magnitude=(0.1, 0.1),
338
- elastic_deform_scale=(0.1, 0.1),
339
- p_synchronize_def_scale_across_axes=0.5,
340
453
  p_rotation=1,
341
- rotation=(-30 / 360 * np.pi, 30 / 360 * np.pi),
342
- p_scaling=1,
343
- scaling=(0.75, 1),
344
- p_synchronize_scaling_across_axes=0.5,
345
- bg_style_seg_sampling=True,
346
- mode_seg='bilinear'
454
+ p_scaling=0,
455
+ rotation=0
456
+ )
457
+ sp2 = SpatialTransform(
458
+ patch_size=(48, 52, 54),
459
+ patch_center_dist_from_border=0,
460
+ random_crop=True,
461
+ p_elastic_deform=0,
462
+ p_rotation=0,
463
+ p_scaling=0,
347
464
  )
348
465
 
349
- data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
350
- if seg is not None:
351
- data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
352
- # out = sp(**data_dict)
353
- #
354
- # view_batch(out['image'], out['segmentation'])
355
-
356
- from time import time
357
- times = []
358
- for _ in range(10):
359
- data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
360
- if seg is not None:
361
- data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
362
- st = time()
363
- out = sp(**data_dict)
364
- times.append(time() - st)
365
- print(np.median(times))
466
+ patch = torch.zeros((1, 64, 60, 68))
467
+ patch[:, :, 10, 30] = 1
468
+ patch[:, 50, :, 30] = 1
469
+ patch[:, 40, 20, :] = 1
470
+ SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
471
+
472
+ center_coords = [30, 28, 44]
473
+ params = sp.get_parameters(image=patch)
474
+ params['center_location_in_pixels'] = center_coords
475
+ params2 = sp2.get_parameters(image=patch)
476
+ params2['center_location_in_pixels'] = center_coords
477
+ transformed = sp._apply_to_image(patch, **params)
478
+ transformed2 = sp._apply_to_image(patch, **params)
479
+
480
+ SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
481
+ SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
482
+
483
+
484
+
485
+ ####################
486
+ # This is exploratory code to check how to retrieve coordinates. I used it to verify that grid_sample does in fact
487
+ # use coordinates in reversed dimension order (zyx and not xyz)
488
+ ####################
489
+ # # create a dummy input which has a unique shape in each axis
490
+ # p = torch.zeros((1, 1, 8, 16, 32))
491
+ # # set one pixel to 1
492
+ # p[:, :, 4, 0, 31] = 1
493
+ # # now create an identity grid. I have verified that this grid yields the same image as the input when used in grid_sample. So the grid is correct
494
+ # grid = _create_identity_grid((8, 16, 32)).contiguous() # grid is shape torch.Size([8, 16, 32, 3])
495
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
496
+ # assert torch.all(out == p) # this passes
497
+ # # reduce the grid to the location we are interested in. That are the coordinates where we placed the 1. The 4:5 etc is only so that we keep the number of dimensions
498
+ # grid = grid[4:5, 0:1, 31:32]
499
+ # # What coordinate would we expect? Note that grid is [-1, 1]
500
+ # # For the first dimension, coordinate 4 out of shape 8 is approximately in the middle, so about 0
501
+ # # For the second dimension, coordinate 0 out of shape 16 is very low, so we expect -1 ish (remember there is aligned corners and shit)
502
+ # # For the third dimension, coordinate 31 out of shape 32 is very high, so we expect 1 ish (remember there is aligned corners and shit)
503
+ # # So we expect [0, -1, 1]
504
+ # # What do we get?
505
+ # print(grid)
506
+ # # > tensor([[[[ 0.9688, -0.9375, 0.1250]]]])
507
+ # # not what we expect
508
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
509
+ # assert out.item() == 1
@@ -16,4 +16,8 @@ class RandomTransform(BasicTransform):
16
16
  if params['apply_transform']:
17
17
  return self.transform(**data_dict)
18
18
  else:
19
- return data_dict
19
+ return data_dict
20
+
21
+ def __repr__(self):
22
+ ret_str = f"{type(self).__name__}(p={self.apply_probability}, transform={self.transform})"
23
+ return ret_str
@@ -0,0 +1,24 @@
1
+ from typing import Union, List, Tuple
2
+ import torch
3
+
4
+ from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
5
+
6
+
7
+ class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
8
+ def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
9
+ super().__init__()
10
+ self.regions = [torch.Tensor(i) if not isinstance(i, int) else torch.Tensor([i]) for i in regions]
11
+ self.channel_in_seg = channel_in_seg
12
+
13
+ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
14
+ num_regions = len(self.regions)
15
+ region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
16
+ for region_id, region_labels in enumerate(self.regions):
17
+ if len(region_labels) == 1:
18
+ region_output[region_id] = segmentation[self.channel_in_seg] == region_labels
19
+ else:
20
+ region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], region_labels)
21
+ # we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to
22
+ # device followed by cast on device should be faster than having fp32 here and transferring that
23
+ return region_output
24
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: batchgeneratorsv2
3
- Version: 0.1.1
3
+ Version: 0.2
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "batchgeneratorsv2"
3
- version = "0.1.1"
3
+ version = "0.2"
4
4
  requires-python = ">=3.9"
5
5
  description = "Batchgenerators but better"
6
6
  readme = "readme.md"
@@ -1,23 +0,0 @@
1
- from typing import Union, List, Tuple
2
-
3
- import torch
4
-
5
- from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
6
-
7
-
8
- class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
9
- def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
10
- super().__init__()
11
- self.regions = regions
12
- self.channel_in_seg = channel_in_seg
13
-
14
- def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
15
- num_regions = len(self.regions)
16
- region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
17
- if isinstance(region_labels, int) or len(region_labels) == 1:
18
- if not isinstance(region_labels, int):
19
- region_labels = region_labels[0]
20
- region_output[:, region_id] = seg[:, self.seg_channel] == region_labels
21
- else:
22
- region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
23
- return region_output.to(segmentation.dtype)