batchgeneratorsv2 0.1.1__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/PKG-INFO +1 -1
  2. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +1 -1
  3. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +1 -1
  4. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/base/basic_transform.py +5 -0
  5. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/contrast.py +3 -0
  6. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/spatial.py +267 -90
  7. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/random.py +5 -1
  8. batchgeneratorsv2-0.2.1/batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  9. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/PKG-INFO +1 -1
  10. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/pyproject.toml +1 -1
  11. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/seg_to_regions.py +0 -23
  12. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/LICENSE +0 -0
  13. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/__init__.py +0 -0
  14. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/__init__.py +0 -0
  15. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  16. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
  17. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/dataloading/__init__.py +0 -0
  18. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/helpers/__init__.py +0 -0
  19. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
  20. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/__init__.py +0 -0
  21. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/base/__init__.py +0 -0
  22. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  23. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/brightness.py +0 -0
  24. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/gamma.py +0 -0
  25. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -0
  26. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +0 -0
  28. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
  29. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
  30. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +0 -0
  32. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  33. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/low_resolution.py +0 -0
  34. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
  35. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  36. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/compose.py +0 -0
  37. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
  38. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
  39. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +0 -0
  40. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/pseudo2d.py +0 -0
  41. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/remove_label.py +0 -0
  42. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/SOURCES.txt +0 -0
  43. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
  44. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/requires.txt +0 -0
  45. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
  46. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/readme.md +0 -0
  47. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/setup.cfg +0 -0
  48. {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: batchgeneratorsv2
3
- Version: 0.1.1
3
+ Version: 0.2.1
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -87,4 +87,4 @@ if __name__ == '__main__':
87
87
  compute_times[i].append(time() - st)
88
88
 
89
89
  for t, ct in zip(tr_transforms, compute_times):
90
- print(t.__class__.__name__, np.percentile(ct, 20))
90
+ print(t.__class__.__name__, np.mean(ct))
@@ -135,4 +135,4 @@ if __name__ == '__main__':
135
135
  compute_times[i].append(time() - st)
136
136
 
137
137
  for t, ct in zip(transforms, compute_times):
138
- print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.percentile(ct, 20))
138
+ print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.mean(ct))
@@ -53,6 +53,11 @@ class BasicTransform(abc.ABC):
53
53
  def get_parameters(self, **data_dict) -> dict:
54
54
  return {}
55
55
 
56
+ def __repr__(self):
57
+ ret_str = str(type(self).__name__) + "( " + ", ".join(
58
+ [key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )"
59
+ return ret_str
60
+
56
61
 
57
62
  class ImageOnlyTransform(BasicTransform):
58
63
  def apply(self, data_dict: dict, **params) -> dict:
@@ -22,6 +22,9 @@ class BGContrast():
22
22
  def __call__(self, *args, **kwargs):
23
23
  return self.sample_contrast(*args, **kwargs)
24
24
 
25
+ def __repr__(self):
26
+ return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
27
+
25
28
  class ContrastTransform(ImageOnlyTransform):
26
29
  def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
27
30
  super().__init__()
@@ -1,6 +1,9 @@
1
1
  from copy import deepcopy
2
2
  from typing import Tuple, List, Union
3
3
 
4
+ import math
5
+
6
+ import SimpleITK
4
7
  import numpy as np
5
8
  import pandas as pd
6
9
  import torch
@@ -14,17 +17,25 @@ from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
14
17
 
15
18
 
16
19
  class SpatialTransform(BasicTransform):
17
- def __init__(self, patch_size: Tuple[int, ...],
20
+ def __init__(self,
21
+ patch_size: Tuple[int, ...],
18
22
  patch_center_dist_from_border: Union[int, List[int], Tuple[int, ...]],
19
23
  random_crop: bool,
20
- p_elastic_deform: float = 0, elastic_deform_scale: RandomScalar = (0, 0.2),
24
+ p_elastic_deform: float = 0,
25
+ elastic_deform_scale: RandomScalar = (0, 0.2),
21
26
  elastic_deform_magnitude: RandomScalar = (0, 0.2),
22
- p_synchronize_def_scale_across_axes: float = False,
23
- p_rotation: float = 0, rotation: RandomScalar = (0, 2 * np.pi),
24
- p_scaling: float = 0, scaling: RandomScalar = (0.7, 1.3), p_synchronize_scaling_across_axes: float = False,
27
+ p_synchronize_def_scale_across_axes: float = 0,
28
+ p_rotation: float = 0,
29
+ rotation: RandomScalar = (0, 2 * np.pi),
30
+ p_scaling: float = 0,
31
+ scaling: RandomScalar = (0.7, 1.3),
32
+ p_synchronize_scaling_across_axes: float = 0,
25
33
  bg_style_seg_sampling: bool = True,
26
34
  mode_seg: str = 'bilinear'
27
35
  ):
36
+ """
37
+ magnitude must be given in pixels!
38
+ """
28
39
  super().__init__()
29
40
  self.patch_size = patch_size
30
41
  if not isinstance(patch_center_dist_from_border, (tuple, list)):
@@ -51,21 +62,23 @@ class SpatialTransform(BasicTransform):
51
62
  do_deform = np.random.uniform() < self.p_elastic_deform
52
63
 
53
64
  if do_rotation:
54
- angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(dim)]
65
+ angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(0, 3)]
55
66
  else:
56
67
  angles = [0] * dim
57
68
  if do_scale:
58
- scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in
59
- range(dim)] if np.random.uniform() < self.p_synchronize_scaling_across_axes else [sample_scalar(
60
- self.scaling, image=data_dict['image'], dim=None)] * dim
69
+ if np.random.uniform() <= self.p_synchronize_scaling_across_axes:
70
+ scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=None)] * dim
71
+ else:
72
+ scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(0, 3)]
61
73
  else:
62
74
  scales = [1] * dim
75
+
63
76
  # affine matrix
64
77
  if do_scale or do_rotation:
65
78
  if dim == 3:
66
79
  affine = create_affine_matrix_3d(angles, scales)
67
80
  elif dim == 2:
68
- affine = create_affine_matrix_2d(angles[0], scales)
81
+ affine = create_affine_matrix_2d(angles[-1], scales)
69
82
  else:
70
83
  raise RuntimeError(f'Unsupported dimension: {dim}')
71
84
  else:
@@ -74,19 +87,23 @@ class SpatialTransform(BasicTransform):
74
87
  # elastic deformation. We need to create the displacement field here
75
88
  # we use the method from augment_spatial_2 in batchgenerators
76
89
  if do_deform:
77
- grid_scale = [i / j for i, j in zip(data_dict['image'].shape[1:], self.patch_size)]
78
- deformation_scales = [
79
- sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
80
- for i in
81
- range(dim)] if np.random.uniform() < self.p_synchronize_scaling_across_axes else [sample_scalar(
82
- self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)] * dim
90
+ if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:
91
+ deformation_scales = [
92
+ sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)
93
+ ] * dim
94
+ else:
95
+ deformation_scales = [
96
+ sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
97
+ for i in range(0, 3)
98
+ ]
99
+
83
100
  # sigmas must be in pixels, as this will be applied to the deformation field
84
101
  sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]
85
- # the magnitude of the deformation field must adhere to the torch's value range for grid_sample, i.e. [-1. 1] and not pixel coordinates. Do not use sigmas here
86
- # we need to correct magnitude by grid_scale to account for the fact that the grid will be wrt to the image size but the magnitude should be wrt the patch size. oof.
102
+
87
103
  magnitude = [
88
104
  sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
89
- dim=i, deformation_scale=deformation_scales[i]) / grid_scale[i] for i in range(dim)]
105
+ dim=i, deformation_scale=deformation_scales[i])
106
+ for i in range(0, 3)]
90
107
  # doing it like this for better memory layout for blurring
91
108
  offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))
92
109
 
@@ -98,7 +115,7 @@ class SpatialTransform(BasicTransform):
98
115
 
99
116
  # fft numpy, this is faster o.O
100
117
  tmp = np.fft.fftn(offsets[d].numpy())
101
- tmp = fourier_gaussian(tmp, sigmas)
118
+ tmp = fourier_gaussian(tmp, sigmas[d])
102
119
  offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
103
120
 
104
121
  mx = torch.max(torch.abs(offsets[d]))
@@ -106,13 +123,13 @@ class SpatialTransform(BasicTransform):
106
123
  offsets = torch.permute(offsets, (1, 2, 3, 0))
107
124
  else:
108
125
  offsets = None
109
- # grid center must be in [-1, 1] as required by grid_sample
126
+
110
127
  shape = data_dict['image'].shape[1:]
111
128
  if not self.random_crop:
112
129
  center_location_in_pixels = [i / 2 for i in shape]
113
130
  else:
114
131
  center_location_in_pixels = []
115
- for d in range(dim):
132
+ for d in range(0, 3):
116
133
  mn = self.patch_center_dist_from_border[d]
117
134
  mx = shape[d] - self.patch_center_dist_from_border[d]
118
135
  if mx < mn:
@@ -130,15 +147,11 @@ class SpatialTransform(BasicTransform):
130
147
  # No spatial transformation is being done. Round grid_center and crop without having to interpolate.
131
148
  # This saves compute.
132
149
  # cropping requires the center to be given as integer coordinates
133
- img = crop_tensor(img, [round(i) for i in params['center_location_in_pixels']], self.patch_size, pad_mode='constant',
150
+ img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']], self.patch_size, pad_mode='constant',
134
151
  pad_kwargs={'value': 0})
135
152
  return img
136
153
  else:
137
- grid = _create_identity_grid(self.patch_size)
138
-
139
- # the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
140
- grid_scale = torch.Tensor([i / j for i, j in zip(img.shape[1:], self.patch_size)])
141
- grid /= grid_scale
154
+ grid = _create_centered_identity_grid2(self.patch_size)
142
155
 
143
156
  # we deform first, then rotate
144
157
  if params['elastic_offsets'] is not None:
@@ -147,11 +160,16 @@ class SpatialTransform(BasicTransform):
147
160
  grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
148
161
 
149
162
  # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
150
- mn = grid.mean(dim=list(range(img.ndim - 1)))
151
- new_center = torch.Tensor(
152
- [(j / (i / 2) - 1) for i, j in zip(img.shape[1:], params['center_location_in_pixels'])])
153
- grid += - mn + new_center
154
- return grid_sample(img[None], grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)[0]
163
+ # only do this if we elastic deform
164
+ if params['elastic_offsets'] is not None:
165
+ mn = grid.mean(dim=list(range(img.ndim - 1)))
166
+ else:
167
+ mn = 0
168
+
169
+ new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])])
170
+ grid += (new_center - mn)
171
+ return grid_sample(img[None], _convert_my_grid_to_grid_sample_grid(grid, img.shape[1:])[None],
172
+ mode='bilinear', padding_mode="zeros", align_corners=False)[0]
155
173
 
156
174
  def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
157
175
  segmentation = segmentation.contiguous()
@@ -159,15 +177,14 @@ class SpatialTransform(BasicTransform):
159
177
  # No spatial transformation is being done. Round grid_center and crop without having to interpolate.
160
178
  # This saves compute.
161
179
  # cropping requires the center to be given as integer coordinates
162
- segmentation = crop_tensor(segmentation, [round(i) for i in params['center_location_in_pixels']], self.patch_size,
163
- pad_mode='constant', pad_kwargs={'value': 0})
180
+ segmentation = crop_tensor(segmentation,
181
+ [math.floor(i) for i in params['center_location_in_pixels']],
182
+ self.patch_size,
183
+ pad_mode='constant',
184
+ pad_kwargs={'value': 0})
164
185
  return segmentation
165
186
  else:
166
- grid = _create_identity_grid(self.patch_size)
167
-
168
- # the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
169
- grid_scale = torch.Tensor([i / j for i, j in zip(segmentation.shape[1:], self.patch_size)])
170
- grid /= grid_scale
187
+ grid = _create_centered_identity_grid2(self.patch_size)
171
188
 
172
189
  # we deform first, then rotate
173
190
  if params['elastic_offsets'] is not None:
@@ -176,10 +193,15 @@ class SpatialTransform(BasicTransform):
176
193
  grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
177
194
 
178
195
  # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
179
- mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
180
- new_center = torch.Tensor(
181
- [(j / (i / 2) - 1) for i, j in zip(segmentation.shape[1:], params['center_location_in_pixels'])])
182
- grid += - mn + new_center
196
+ if params['elastic_offsets'] is not None:
197
+ mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
198
+ else:
199
+ mn = 0
200
+
201
+ new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], segmentation.shape[1:])])
202
+
203
+ grid += (new_center - mn)
204
+ grid = _convert_my_grid_to_grid_sample_grid(grid, segmentation.shape[1:])
183
205
 
184
206
  if self.mode_seg == 'nearest':
185
207
  result_seg = grid_sample(
@@ -279,15 +301,38 @@ def create_affine_matrix_2d(rotation_angle, scaling_factors):
279
301
  return RS
280
302
 
281
303
 
282
- def _create_identity_grid(size: List[int]) -> Tensor:
283
- space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size[::-1]]
304
+ # def _create_identity_grid(size: List[int]) -> Tensor:
305
+ # space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size[::-1]]
306
+ # grid = torch.meshgrid(space, indexing="ij")
307
+ # grid = torch.stack(grid, -1)
308
+ # spatial_dims = list(range(len(size)))
309
+ # grid = grid.permute((*spatial_dims[::-1], len(size)))
310
+ # return grid
311
+
312
+
313
+ def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
314
+ space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
284
315
  grid = torch.meshgrid(space, indexing="ij")
285
316
  grid = torch.stack(grid, -1)
286
- spatial_dims = list(range(len(size)))
287
- grid = grid.permute((*spatial_dims[::-1], len(size)))
288
317
  return grid
289
318
 
290
319
 
320
+ def _convert_my_grid_to_grid_sample_grid(my_grid: torch.Tensor, original_shape: Union[Tuple[int, ...], List[int]]):
321
+ # rescale
322
+ for d in range(len(original_shape)):
323
+ s = original_shape[d]
324
+ my_grid[..., d] /= (s / 2)
325
+ my_grid = torch.flip(my_grid, (len(my_grid.shape) - 1, ))
326
+ # my_grid = my_grid.flip((len(my_grid.shape) - 1,))
327
+ return my_grid
328
+
329
+
330
+ # size = (4, 5, 6)
331
+ # grid_old = _create_identity_grid(size)
332
+ # grid_new = _create_centered_identity_grid2(size)
333
+ # grid_new_converted = _convert_my_grid_to_grid_sample_grid(grid_new, size)
334
+ # torch.all(torch.isclose(grid_new_converted, grid_old))
335
+
291
336
  # An alternative way of generating the displacement field
292
337
  # def displacement_field(data: torch.Tensor):
293
338
  # downscaling_global = np.random.uniform() ** 2 * 4 + 2
@@ -308,58 +353,190 @@ def _create_identity_grid(size: List[int]) -> Tensor:
308
353
 
309
354
 
310
355
  if __name__ == '__main__':
311
- torch.set_num_threads(1)
356
+ # torch.set_num_threads(1)
357
+ #
358
+ # shape = (128, 128, 128)
359
+ # patch_size = (128, 128, 128)
360
+ # labels = 2
361
+ #
362
+ #
363
+ # # seg = torch.rand([i // 32 for i in shape]) * labels
364
+ # # seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
365
+ # # decimals=0).to(torch.int16)
366
+ # # img = torch.ones((1, *shape))
367
+ # # img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
368
+ #
369
+ #
370
+ # import SimpleITK as sitk
371
+ # # img = camera()
372
+ # # seg = None
373
+ # img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/imagesTr/BraTS2021_00000_0000.nii.gz'))
374
+ # seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/labelsTr/BraTS2021_00000.nii.gz'))
375
+ #
376
+ # patch_size = (192, 192, 192)
377
+ # sp = SpatialTransform(
378
+ # patch_size=(192, 192, 192),
379
+ # patch_center_dist_from_border=[i / 2 for i in patch_size],
380
+ # random_crop=True,
381
+ # p_elastic_deform=0,
382
+ # elastic_deform_magnitude=(0.1, 0.1),
383
+ # elastic_deform_scale=(0.1, 0.1),
384
+ # p_synchronize_def_scale_across_axes=0.5,
385
+ # p_rotation=1,
386
+ # rotation=(-30 / 360 * np.pi, 30 / 360 * np.pi),
387
+ # p_scaling=1,
388
+ # scaling=(0.75, 1),
389
+ # p_synchronize_scaling_across_axes=0.5,
390
+ # bg_style_seg_sampling=True,
391
+ # mode_seg='bilinear'
392
+ # )
393
+ #
394
+ # data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
395
+ # if seg is not None:
396
+ # data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
397
+ # # out = sp(**data_dict)
398
+ # #
399
+ # # view_batch(out['image'], out['segmentation'])
400
+ #
401
+ # from time import time
402
+ # times = []
403
+ # for _ in range(10):
404
+ # data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
405
+ # if seg is not None:
406
+ # data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
407
+ # st = time()
408
+ # out = sp(**data_dict)
409
+ # times.append(time() - st)
410
+ # print(np.median(times))
411
+
412
+
413
+ #################
414
+ # with this part we can qualitatively test that the correct axes are being augmented. Just set one of the probs to 1 and off you go
415
+ #################
416
+
417
+ def eldef_scale(image, dim, patch_size):
418
+ return 0.1
312
419
 
313
- shape = (128, 128, 128)
314
- patch_size = (128, 128, 128)
315
- labels = 2
420
+ def eldef_magnitude(image, dim, patch_size, deformation_scale):
421
+ return 10 if dim == 2 else 0
316
422
 
423
+ def rot(image, dim):
424
+ return 45/360 * 2 * np.pi if dim == 0 else 0
317
425
 
318
- # seg = torch.rand([i // 32 for i in shape]) * labels
319
- # seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
320
- # decimals=0).to(torch.int16)
321
- # img = torch.ones((1, *shape))
322
- # img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
426
+ def scaling(image, dim):
427
+ return 0.5 if dim == 0 else 1
323
428
 
429
+ # lines
430
+ patch = torch.zeros((1, 64, 60, 68))
431
+ patch[:, :, 10, 30] = 1
432
+ patch[:, 50, :, 30] = 1
433
+ patch[:, 40, 20, :] = 1
324
434
 
325
- import SimpleITK as sitk
326
- # img = camera()
327
- # seg = None
328
- img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/imagesTr/BraTS2021_00000_0000.nii.gz'))
329
- seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/labelsTr/BraTS2021_00000.nii.gz'))
435
+ # patch_block
436
+ patch_block = torch.zeros((1, 64, 60, 68))
437
+ patch_block[:, 22:42, 20:40, 24:44] = 1
438
+
439
+ patch_line = torch.zeros((1, 64, 60, 128))
440
+ patch_line[:, 22:24, 30:32, 10:-10] = 1
441
+ use = patch_line
330
442
 
331
- patch_size = (192, 192, 192)
332
443
  sp = SpatialTransform(
333
- patch_size=(192, 192, 192),
334
- patch_center_dist_from_border=[i / 2 for i in patch_size],
335
- random_crop=True,
444
+ patch_size=patch.shape[1:],
445
+ patch_center_dist_from_border=0,
446
+ random_crop=False,
336
447
  p_elastic_deform=0,
337
- elastic_deform_magnitude=(0.1, 0.1),
338
- elastic_deform_scale=(0.1, 0.1),
339
- p_synchronize_def_scale_across_axes=0.5,
340
448
  p_rotation=1,
341
- rotation=(-30 / 360 * np.pi, 30 / 360 * np.pi),
342
- p_scaling=1,
343
- scaling=(0.75, 1),
344
- p_synchronize_scaling_across_axes=0.5,
345
- bg_style_seg_sampling=True,
449
+ p_scaling=0,
450
+ elastic_deform_scale=eldef_scale,
451
+ elastic_deform_magnitude=eldef_magnitude,
452
+ p_synchronize_def_scale_across_axes=0,
453
+ rotation=rot,
454
+ scaling=scaling,
455
+ p_synchronize_scaling_across_axes=0,
456
+ bg_style_seg_sampling=False,
346
457
  mode_seg='bilinear'
347
458
  )
348
459
 
349
- data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
350
- if seg is not None:
351
- data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
352
- # out = sp(**data_dict)
460
+
461
+ SimpleITK.WriteImage(SimpleITK.GetImageFromArray(use[0].numpy()), 'orig.nii.gz')
462
+
463
+ params = sp.get_parameters(image=use)
464
+ transformed = sp._apply_to_image(use, **params)
465
+
466
+ SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
467
+
468
+ # p = torch.zeros((1, 1, 8, 16, 32))
469
+ # p[:, :, 2:6, 10:16, 10:24] = 1
470
+ # grid = _create_identity_grid(p.shape[2:])
471
+ # grid[:, :, :, 0] *= 0.5
472
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
473
+ # torch.all(out == p)
474
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
475
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
476
+
477
+ #################
478
+ # with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
479
+ #################
480
+
481
+ # sp = SpatialTransform(
482
+ # patch_size=(48, 52, 54),
483
+ # patch_center_dist_from_border=0,
484
+ # random_crop=True,
485
+ # p_elastic_deform=0,
486
+ # p_rotation=1,
487
+ # p_scaling=0,
488
+ # rotation=0
489
+ # )
490
+ # sp2 = SpatialTransform(
491
+ # patch_size=(48, 52, 54),
492
+ # patch_center_dist_from_border=0,
493
+ # random_crop=True,
494
+ # p_elastic_deform=0,
495
+ # p_rotation=0,
496
+ # p_scaling=0,
497
+ # )
498
+ #
499
+ # patch = torch.zeros((1, 64, 60, 68))
500
+ # patch[:, :, 10, 30] = 1
501
+ # patch[:, 50, :, 30] = 1
502
+ # patch[:, 40, 20, :] = 1
503
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
504
+ #
505
+ # center_coords = [50, 10, 16]
506
+ # params = sp.get_parameters(image=patch)
507
+ # params['center_location_in_pixels'] = center_coords
508
+ # params2 = sp2.get_parameters(image=patch)
509
+ # params2['center_location_in_pixels'] = center_coords
510
+ # transformed = sp._apply_to_image(patch, **params)
511
+ # transformed2 = sp._apply_to_image(patch, **params2)
353
512
  #
354
- # view_batch(out['image'], out['segmentation'])
355
-
356
- from time import time
357
- times = []
358
- for _ in range(10):
359
- data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
360
- if seg is not None:
361
- data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
362
- st = time()
363
- out = sp(**data_dict)
364
- times.append(time() - st)
365
- print(np.median(times))
513
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
514
+ # SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
515
+
516
+
517
+
518
+ ####################
519
+ # This is exploratory code to check how to retrieve coordinates. I used it to verify that grid_sample does in fact
520
+ # use coordinates in reversed dimension order (zyx and not xyz)
521
+ ####################
522
+ # # create a dummy input which has a unique shape in each axis
523
+ # p = torch.zeros((1, 1, 8, 16, 32))
524
+ # # set one pixel to 1
525
+ # p[:, :, 4, 0, 31] = 1
526
+ # # now create an identity grid. I have verified that this grid yields the same image as the input when used in grid_sample. So the grid is correct
527
+ # grid = _create_identity_grid((8, 16, 32)).contiguous() # grid is shape torch.Size([8, 16, 32, 3])
528
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
529
+ # assert torch.all(out == p) # this passes
530
+ # # reduce the grid to the location we are interested in. That are the coordinates where we placed the 1. The 4:5 etc is only so that we keep the number of dimensions
531
+ # grid = grid[4:5, 0:1, 31:32]
532
+ # # What coordinate would we expect? Note that grid is [-1, 1]
533
+ # # For the first dimension, coordinate 4 out of shape 8 is approximately in the middle, so about 0
534
+ # # For the second dimension, coordinate 0 out of shape 16 is very low, so we expect -1 ish (remember there is aligned corners and shit)
535
+ # # For the third dimension, coordinate 31 out of shape 32 is very high, so we expect 1 ish (remember there is aligned corners and shit)
536
+ # # So we expect [0, -1, 1]
537
+ # # What do we get?
538
+ # print(grid)
539
+ # # > tensor([[[[ 0.9688, -0.9375, 0.1250]]]])
540
+ # # not what we expect
541
+ # out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
542
+ # assert out.item() == 1
@@ -16,4 +16,8 @@ class RandomTransform(BasicTransform):
16
16
  if params['apply_transform']:
17
17
  return self.transform(**data_dict)
18
18
  else:
19
- return data_dict
19
+ return data_dict
20
+
21
+ def __repr__(self):
22
+ ret_str = f"{type(self).__name__}(p={self.apply_probability}, transform={self.transform})"
23
+ return ret_str
@@ -0,0 +1,24 @@
1
+ from typing import Union, List, Tuple
2
+ import torch
3
+
4
+ from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
5
+
6
+
7
+ class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
8
+ def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
9
+ super().__init__()
10
+ self.regions = [torch.Tensor(i) if not isinstance(i, int) else torch.Tensor([i]) for i in regions]
11
+ self.channel_in_seg = channel_in_seg
12
+
13
+ def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
14
+ num_regions = len(self.regions)
15
+ region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
16
+ for region_id, region_labels in enumerate(self.regions):
17
+ if len(region_labels) == 1:
18
+ region_output[region_id] = segmentation[self.channel_in_seg] == region_labels
19
+ else:
20
+ region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], region_labels)
21
+ # we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to
22
+ # device followed by cast on device should be faster than having fp32 here and transferring that
23
+ return region_output
24
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: batchgeneratorsv2
3
- Version: 0.1.1
3
+ Version: 0.2.1
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "batchgeneratorsv2"
3
- version = "0.1.1"
3
+ version = "0.2.1"
4
4
  requires-python = ">=3.9"
5
5
  description = "Batchgenerators but better"
6
6
  readme = "readme.md"
@@ -1,23 +0,0 @@
1
- from typing import Union, List, Tuple
2
-
3
- import torch
4
-
5
- from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
6
-
7
-
8
- class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
9
- def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
10
- super().__init__()
11
- self.regions = regions
12
- self.channel_in_seg = channel_in_seg
13
-
14
- def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
15
- num_regions = len(self.regions)
16
- region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
17
- if isinstance(region_labels, int) or len(region_labels) == 1:
18
- if not isinstance(region_labels, int):
19
- region_labels = region_labels[0]
20
- region_output[:, region_id] = seg[:, self.seg_channel] == region_labels
21
- else:
22
- region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
23
- return region_output.to(segmentation.dtype)