batchgeneratorsv2 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2/helpers/fft_conv.py +149 -0
- batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
- batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
- batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
- batchgeneratorsv2/transforms/local/__init__.py +0 -0
- batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
- batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
- batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
- batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
- batchgeneratorsv2/transforms/local/local_transform.py +86 -0
- batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
- batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
- batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
- batchgeneratorsv2/transforms/noise/rician.py +61 -0
- batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
- batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
- batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
- batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
- batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
- batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
- batchgeneratorsv2/transforms/utils/random.py +46 -0
- batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
- batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
- batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
- batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
- batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
- batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,601 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from typing import Tuple, List, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
from scipy.ndimage import fourier_gaussian
|
|
9
|
+
from torch.nn.functional import grid_sample
|
|
10
|
+
|
|
11
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
12
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
13
|
+
from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SpatialTransform(BasicTransform):
    """
    Crop, rotate, scale and elastically deform a sample with a single resampling step.

    get_parameters draws the random decisions once and, if any of rotation, scaling or elastic deformation is
    active, precomputes one sampling grid that is shared by image, segmentation and regression target. If none of
    them is active, the patch is cut out with crop_tensor and no interpolation is performed.
    """

    def __init__(self,
                 patch_size: Tuple[int, ...],
                 patch_center_dist_from_border: Union[int, List[int], Tuple[int, ...]],
                 random_crop: bool,
                 p_elastic_deform: float = 0,
                 elastic_deform_scale: RandomScalar = (0, 0.2),
                 elastic_deform_magnitude: RandomScalar = (0, 0.2),
                 p_synchronize_def_scale_across_axes: float = 0,
                 p_rotation: float = 0,
                 rotation: RandomScalar = (0, 2 * np.pi),
                 p_rot_per_axis: float = 1,
                 p_scaling: float = 0,
                 scaling: RandomScalar = (0.7, 1.3),
                 p_synchronize_scaling_across_axes: float = 0,
                 bg_style_seg_sampling: bool = True,
                 mode_seg: str = 'bilinear',
                 border_mode_seg: str = "zeros",
                 center_deformation: bool = True,
                 mode_image: str = 'bilinear',
                 padding_mode_image: str = "zeros",
                 padding_value_seg: float = 0,
                 padding_value_image: float = 0,
                 align_corners: bool = False
                 ):
        """
        magnitude must be given in pixels!
        deformation scale is given as a fraction of the edge length (it is multiplied with patch_size to obtain
        the blurring sigma in pixels)

        patch_size: spatial size of the returned patch.
        patch_center_dist_from_border: minimum distance of the sampled patch center from the image border, per
            axis. A single number is repeated for every axis of patch_size. Only used if random_crop is True.
        random_crop: if True the patch center is drawn uniformly per axis, otherwise the image center is used.
        p_elastic_deform / p_rotation / p_scaling: probabilities of applying the respective augmentation.
        p_synchronize_def_scale_across_axes: probability of using one deformation scale for all axes.
        p_rot_per_axis: if < 1, each sampled angle is reset to 0 with probability (1 - p_rot_per_axis).
        scaling: larger numbers = smaller objects!
        p_synchronize_scaling_across_axes: probability of using one scaling factor for all axes.
        bg_style_seg_sampling: selects the label-wise resampling strategy used when mode_seg is not 'nearest'
            (see _apply_to_segmentation).
        center_deformation: if True and an elastic deformation was drawn, the mean of the deformed grid (instead
            of its nominal center) is moved onto the crop center.

        padding_mode_image: see torch grid_sample documentation. This currently applies to image and regression target
        because both call self._apply_to_image. Can be "zeros", "constant", "reflection", "border"

        border_mode_seg: can be "zeros", "constant", "reflection", "border". padding values are only considered for
        the corresponding "constant" modes.

        align_corners: forwarded to grid_sample.
        """
        super().__init__()
        self.patch_size = patch_size
        if not isinstance(patch_center_dist_from_border, (tuple, list)):
            patch_center_dist_from_border = [patch_center_dist_from_border] * len(patch_size)
        self.patch_center_dist_from_border = patch_center_dist_from_border
        self.random_crop = random_crop
        self.p_elastic_deform = p_elastic_deform
        self.elastic_deform_scale = elastic_deform_scale  # sigma for blurring offsets, as fraction of patch size. Larger values mean coarser deformation
        self.elastic_deform_magnitude = elastic_deform_magnitude  # determines the maximum displacement, measured in pixels!!
        self.p_rotation = p_rotation
        self.rotation = rotation
        self.p_rot_per_axis = p_rot_per_axis
        self.p_scaling = p_scaling
        self.scaling = scaling  # larger numbers = smaller objects!
        self.p_synchronize_scaling_across_axes = p_synchronize_scaling_across_axes
        self.p_synchronize_def_scale_across_axes = p_synchronize_def_scale_across_axes
        self.bg_style_seg_sampling = bg_style_seg_sampling
        self.mode_seg = mode_seg
        self.border_mode_seg = border_mode_seg
        self.center_deformation = center_deformation
        self.mode_image = mode_image
        self.padding_mode_image = padding_mode_image
        self.padding_value_seg = padding_value_seg
        self.padding_value_image = padding_value_image
        self.align_corners = align_corners
        self._grid_cache = {}  # key: tuple(patch_size) -> float32 centered identity grid

    def _get_base_grid_clone(self) -> torch.Tensor:
        """Return a fresh copy of the (cached) centered identity grid for the current patch size."""
        key = tuple(self.patch_size)
        g = self._grid_cache.get(key)
        if g is None:
            g = _create_centered_identity_grid2(self.patch_size).float().contiguous()
            self._grid_cache[key] = g
        # callers modify the grid in place, so the cached tensor must never be handed out directly
        return g.clone()

    @staticmethod
    def _get_crop_pad_settings(padding_mode: str, padding_value: float):
        """
        Translate a grid_sample-style padding mode into the (pad_mode, pad_kwargs) pair passed to crop_tensor.

        Raises RuntimeError for an unknown padding_mode.
        """
        if padding_mode == 'reflection':
            return 'reflect', {}
        if padding_mode == 'border':
            return 'replicate', {}
        if padding_mode == 'zeros':
            return 'constant', {'value': 0}
        if padding_mode == 'constant':
            return 'constant', {'value': padding_value}
        raise RuntimeError(f'Unknown pad mode: {padding_mode}')

    @staticmethod
    def _get_grid_sample_padding_mode(padding_mode: str) -> str:
        """
        Map this transform's padding mode onto a padding_mode accepted by grid_sample.

        'constant' is sampled with 'zeros'; a non-zero constant is written afterwards by the constant padding
        fix-up. Raises RuntimeError for an unknown padding_mode.
        """
        if padding_mode in ('zeros', 'constant'):
            return 'zeros'
        if padding_mode in ('border', 'reflection'):
            return padding_mode
        raise RuntimeError(f'Unknown pad mode: {padding_mode}')

    @staticmethod
    def _requires_constant_padding_fixup(padding_mode: str, padding_value: float) -> bool:
        """True if out-of-bounds voxels must be overwritten after grid_sample (non-zero constant padding)."""
        return padding_mode == 'constant' and padding_value != 0

    def _compute_out_of_bounds_mask(self, grid: torch.Tensor, spatial_shape: Tuple[int, ...]) -> torch.Tensor:
        """
        Flag every output location whose sampling coordinate lies outside the valid range in any axis.

        grid: sampling grid as passed to grid_sample (normalized coordinates, last axis of size ndim).
        spatial_shape: spatial shape of the tensor that is being sampled, in array axis order.

        Returns a boolean tensor with the spatial shape of the grid.
        """
        if self.align_corners:
            lo = grid.new_tensor(-1.)
            hi = grid.new_tensor(1.)
        else:
            # The last axis of the grid was flipped by _convert_my_grid_to_grid_sample_grid, i.e. grid[..., 0]
            # addresses the LAST spatial axis. The per-axis bounds must follow the same (reversed) order,
            # otherwise non-cubic inputs are checked against the extent of the wrong axis.
            size = grid.new_tensor(tuple(spatial_shape)[::-1])
            lo = -1 + 1 / size
            hi = 1 - 1 / size
        return ((grid < lo) | (grid > hi)).any(dim=-1)

    def get_parameters(self, **data_dict) -> dict:
        """
        Draw all random decisions for one sample and precompute the sampling grid.

        Reads data_dict['image']; its first axis is skipped, the remaining axes are treated as spatial (2 or 3
        are supported when rotation/scaling is active).

        Returns a dict with
            'center_location_in_pixels': crop center per spatial axis (floats),
            'grid': grid for grid_sample, or None if neither affine nor elastic deformation is applied,
            'affine', 'elastic_offsets': kept for debugging only.

        Raises RuntimeError if rotation or scaling is requested for an unsupported dimensionality.
        """
        dim = data_dict['image'].ndim - 1

        do_rotation = np.random.uniform() < self.p_rotation
        do_scale = np.random.uniform() < self.p_scaling
        do_deform = np.random.uniform() < self.p_elastic_deform

        if do_rotation:
            angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(dim)]
            if self.p_rot_per_axis < 1:
                for i in range(dim):
                    if np.random.uniform() > self.p_rot_per_axis:
                        angles[i] = 0
        else:
            angles = [0] * dim
        if do_scale:
            if np.random.uniform() <= self.p_synchronize_scaling_across_axes:
                scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=None)] * dim
            else:
                scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(dim)]
        else:
            scales = [1] * dim

        # affine matrix
        if do_scale or do_rotation:
            if dim == 3:
                affine = create_affine_matrix_3d(angles, scales)
            elif dim == 2:
                affine = create_affine_matrix_2d(angles[-1], scales)
            else:
                raise RuntimeError(f'Unsupported dimension: {dim}')
        else:
            affine = None  # this will allow us to detect that we can skip computations

        # elastic deformation. We need to create the displacement field here
        # we use the method from augment_spatial_2 in batchgenerators
        if do_deform:
            if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:
                deformation_scales = [
                    sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=None,
                                  patch_size=self.patch_size)
                    ] * dim
            else:
                deformation_scales = [
                    sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i,
                                  patch_size=self.patch_size)
                    for i in range(dim)
                    ]

            # sigmas must be in pixels, as this will be applied to the deformation field
            sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]

            magnitude = [
                sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
                              dim=i, deformation_scale=deformation_scales[i])
                for i in range(dim)]
            # doing it like this for better memory layout for blurring
            offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))

            # all the additional time elastic deform takes is spent here
            for d in range(dim):
                # Gaussian blur in the frequency domain via numpy/scipy (the original author found this faster
                # than a torch based FFT blur)
                tmp = np.fft.fftn(offsets[d].numpy())
                tmp = fourier_gaussian(tmp, sigmas[d])
                offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)

                # rescale so that the largest absolute displacement equals the sampled magnitude
                mx = torch.max(torch.abs(offsets[d]))
                offsets[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf))
            # move the component axis to the back: (dim, *patch_size) -> (*patch_size, dim)
            spatial_dims = tuple(range(1, dim + 1))
            offsets = torch.permute(offsets, (*spatial_dims, 0))
        else:
            offsets = None

        shape = data_dict['image'].shape[1:]
        if not self.random_crop:
            center_location_in_pixels = [i / 2 for i in shape]
        else:
            center_location_in_pixels = []
            for d in range(dim):
                mn = self.patch_center_dist_from_border[d]
                mx = shape[d] - self.patch_center_dist_from_border[d]
                if mx < mn:
                    # image too small for the requested border distance: fall back to the image center
                    center_location_in_pixels.append(shape[d] / 2)
                else:
                    center_location_in_pixels.append(np.random.uniform(mn, mx))
        # Precompute the deformed grid once (shared by image, segmentation, regression target)
        if affine is not None or offsets is not None:
            grid = self._get_base_grid_clone()

            # we deform first, then rotate
            if offsets is not None:
                grid += offsets
            if affine is not None:
                grid = torch.matmul(grid, torch.from_numpy(affine).float())

            # we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
            # only do this if we elastic deform
            if self.center_deformation and offsets is not None:
                mn = grid.mean(dim=list(range(len(shape))))
            else:
                mn = 0

            new_center = torch.Tensor([c - s / 2 for c, s in zip(center_location_in_pixels, shape)])
            grid += (new_center - mn)
            grid = _convert_my_grid_to_grid_sample_grid(grid, shape)
        else:
            grid = None

        return {
            'center_location_in_pixels': center_location_in_pixels,
            'grid': grid,
            # we don't need them but we keep them so that we can debug better
            'affine': affine,
            'elastic_offsets': offsets,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Resample (or, without a grid, crop) the image to patch_size.

        img: tensor whose first axis is not spatial; the remaining axes are sampled.
        params: output of get_parameters.
        """
        if params['grid'] is None:
            # No spatial transformation is being done. Floor the crop center and crop without having to interpolate.
            # This saves compute.
            # cropping requires the center to be given as integer coordinates
            pad_mode, pad_kwargs = self._get_crop_pad_settings(self.padding_mode_image, self.padding_value_image)
            return crop_tensor(
                img,
                [math.floor(i) for i in params['center_location_in_pixels']],
                self.patch_size,
                pad_mode=pad_mode,
                pad_kwargs=pad_kwargs,
            )

        grid = params['grid']
        result = grid_sample(
            img[None],
            grid[None],
            mode=self.mode_image,
            padding_mode=self._get_grid_sample_padding_mode(self.padding_mode_image),
            align_corners=self.align_corners,
        )[0]
        if self._requires_constant_padding_fixup(self.padding_mode_image, self.padding_value_image):
            # grid_sample can only pad with zeros; write the requested constant into out-of-bounds locations
            out_of_bounds_mask = self._compute_out_of_bounds_mask(grid, img.shape[1:])
            result.masked_fill_(out_of_bounds_mask.unsqueeze(0), self.padding_value_image)
        return result

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        """
        Resample (or, without a grid, crop) the segmentation to patch_size while keeping valid label values.

        With mode_seg == 'nearest' the label map is sampled directly. Otherwise every label present in a channel
        is turned into a binary mask which is interpolated on its own:
          - bg_style_seg_sampling=True: a voxel receives a label if that label's interpolated mask is >= 0.5.
          - bg_style_seg_sampling=False: a voxel receives a label if the interpolated mask exceeds 0.7; voxels
            not assigned this way receive the label with the largest interpolated value.
        """
        segmentation = segmentation.contiguous()
        if params['grid'] is None:
            # No spatial transformation is being done. Floor the crop center and crop without having to interpolate.
            # This saves compute.
            # cropping requires the center to be given as integer coordinates
            pad_mode, pad_kwargs = self._get_crop_pad_settings(self.border_mode_seg, self.padding_value_seg)
            return crop_tensor(
                segmentation,
                [math.floor(i) for i in params['center_location_in_pixels']],
                self.patch_size,
                pad_mode=pad_mode,
                pad_kwargs=pad_kwargs,
            )

        grid = params['grid']
        grid_sample_padding_mode = self._get_grid_sample_padding_mode(self.border_mode_seg)

        if self.mode_seg == 'nearest':
            result_seg = grid_sample(
                segmentation[None].float(),
                grid[None],
                mode=self.mode_seg,
                padding_mode=grid_sample_padding_mode,
                align_corners=self.align_corners
            )[0].to(segmentation.dtype)
        else:
            result_seg = torch.zeros((segmentation.shape[0], *self.patch_size), dtype=segmentation.dtype)
            if self.bg_style_seg_sampling:
                for c in range(segmentation.shape[0]):
                    labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
                    # if we only have 2 labels then we can save compute time
                    if len(labels) == 2:
                        out = grid_sample(
                            ((segmentation[c] == labels[1]).float())[None, None],
                            grid[None],
                            mode=self.mode_seg,
                            padding_mode=grid_sample_padding_mode,
                            align_corners=self.align_corners
                        )[0][0] >= 0.5
                        result_seg[c][out] = labels[1]
                        result_seg[c][~out] = labels[0]
                    else:
                        for i, u in enumerate(labels):
                            result_seg[c][
                                grid_sample(
                                    ((segmentation[c] == u).float())[None, None],
                                    grid[None],
                                    mode=self.mode_seg,
                                    padding_mode=grid_sample_padding_mode,
                                    align_corners=self.align_corners
                                )[0][0] >= 0.5] = u
            else:
                for c in range(segmentation.shape[0]):
                    labels = torch.from_numpy(np.sort(pd.unique(segmentation[c].numpy().ravel())))
                    # interpolated masks are stored in float16; they are scaled by scale_factor before sampling
                    tmp = torch.zeros((len(labels), *self.patch_size), dtype=torch.float16)
                    scale_factor = 1000
                    done_mask = torch.zeros(*self.patch_size, dtype=torch.bool)
                    for i, u in enumerate(labels):
                        tmp[i] = grid_sample(
                            ((segmentation[c] == u).float() * scale_factor)[None, None],
                            grid[None],
                            mode=self.mode_seg,
                            padding_mode=grid_sample_padding_mode,
                            align_corners=self.align_corners
                        )[0][0]
                        mask = tmp[i] > (0.7 * scale_factor)
                        result_seg[c][mask] = u
                        done_mask = done_mask | mask
                    if not torch.all(done_mask):
                        # remaining voxels: take the label with the strongest interpolated response
                        result_seg[c][~done_mask] = labels[tmp[:, ~done_mask].argmax(0)]
                    del tmp

        if self._requires_constant_padding_fixup(self.border_mode_seg, self.padding_value_seg):
            # grid_sample can only pad with zeros; write the requested constant into out-of-bounds locations
            out_of_bounds_mask = self._compute_out_of_bounds_mask(grid, segmentation.shape[1:])
            result_seg.masked_fill_(out_of_bounds_mask.unsqueeze(0), self.padding_value_seg)
        del grid
        return result_seg.contiguous()

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        """Regression targets are resampled exactly like images (same interpolation and padding settings)."""
        return self._apply_to_image(regression_target, **params)

    def _apply_to_keypoints(self, keypoints, **params):
        """Not supported by this transform."""
        raise NotImplementedError

    def _apply_to_bbox(self, bbox, **params):
        """Not supported by this transform."""
        raise NotImplementedError
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def create_affine_matrix_3d(rotation_angles, scaling_factors):
    """
    Build the 3x3 matrix Rz @ Ry @ Rx @ S.

    rotation_angles: indexable with three angles in radians, one per rotation axis (index 0, 1, 2).
    scaling_factors: three per-axis scaling factors that form the diagonal of S.
    """
    cos_x, sin_x = np.cos(rotation_angles[0]), np.sin(rotation_angles[0])
    cos_y, sin_y = np.cos(rotation_angles[1]), np.sin(rotation_angles[1])
    cos_z, sin_z = np.cos(rotation_angles[2]), np.sin(rotation_angles[2])

    # elementary rotations about the first, second and third axis
    rot_x = np.array([[1, 0, 0],
                      [0, cos_x, -sin_x],
                      [0, sin_x, cos_x]])
    rot_y = np.array([[cos_y, 0, sin_y],
                      [0, 1, 0],
                      [-sin_y, 0, cos_y]])
    rot_z = np.array([[cos_z, -sin_z, 0],
                      [sin_z, cos_z, 0],
                      [0, 0, 1]])

    # scaling is applied first (right-most factor), then the rotations about x, y and z
    rotation = rot_z @ rot_y @ rot_x
    return rotation @ np.diag(scaling_factors)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def create_affine_matrix_2d(rotation_angle, scaling_factors):
    """
    Build the 2x2 matrix R @ S.

    rotation_angle: in-plane rotation angle in radians.
    scaling_factors: two per-axis scaling factors that form the diagonal of S.
    """
    cos_a = np.cos(rotation_angle)
    sin_a = np.sin(rotation_angle)

    rotation = np.array([[cos_a, -sin_a],
                         [sin_a, cos_a]])
    scaling = np.diag(scaling_factors)

    # scaling is applied first (right-most factor), then the rotation
    return rotation @ scaling
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
|
|
398
|
+
space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
|
|
399
|
+
grid = torch.meshgrid(space, indexing="ij")
|
|
400
|
+
grid = torch.stack(grid, -1)
|
|
401
|
+
return grid
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _convert_my_grid_to_grid_sample_grid(my_grid: torch.Tensor, original_shape: Union[Tuple[int, ...], List[int]]):
|
|
405
|
+
# rescale
|
|
406
|
+
for d in range(len(original_shape)):
|
|
407
|
+
s = original_shape[d]
|
|
408
|
+
my_grid[..., d] /= (s / 2)
|
|
409
|
+
my_grid = torch.flip(my_grid, (len(my_grid.shape) - 1,))
|
|
410
|
+
# my_grid = my_grid.flip((len(my_grid.shape) - 1,))
|
|
411
|
+
return my_grid
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
if __name__ == '__main__':
    # Manual benchmark: times SpatialTransform on one image/segmentation pair read from disk.
    # Requires SimpleITK and the hard-coded files below, so it only runs on the original author's machine.
    torch.set_num_threads(1)

    # only used by the commented-out synthetic data below
    shape = (128, 128, 128)
    patch_size = (128, 128, 128)
    labels = 2


    # seg = torch.rand([i // 32 for i in shape]) * labels
    # seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
    #                      decimals=0).to(torch.int16)
    # img = torch.ones((1, *shape))
    # img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200


    import SimpleITK as sitk
    # img = camera()
    # seg = None
    img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset226_BraTS2024-BraTS-GLI/imagesTr/BraTS-GLI-00005-100_0001.nii.gz'))
    seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset226_BraTS2024-BraTS-GLI/labelsTr/BraTS-GLI-00005-100.nii.gz'))

    # overrides the patch_size defined above; used for the border distance of the transform below
    patch_size = (192, 192, 192)
    sp = SpatialTransform(
        patch_size=(192, 192, 192),
        patch_center_dist_from_border=[i / 2 for i in patch_size],
        random_crop=True,
        p_elastic_deform=0,
        elastic_deform_magnitude=(0.1, 0.1),
        elastic_deform_scale=(0.1, 0.1),
        p_synchronize_def_scale_across_axes=0.5,
        p_rotation=1,
        # NOTE(review): 30 / 360 * np.pi is pi / 12, i.e. +-15 degrees - confirm whether +-30 degrees was intended
        rotation=(-30 / 360 * np.pi, 30 / 360 * np.pi),
        p_scaling=1,
        scaling=(0.75, 1),
        p_synchronize_scaling_across_axes=0.5,
        bg_style_seg_sampling=True,
        mode_seg='bilinear'
    )

    # a leading axis of size 1 is added in front of the spatial axes
    data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
    if seg is not None:
        data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
    # out = sp(**data_dict)
    #
    # view_batch(out['image'], out['segmentation'])

    from time import time
    times = []
    for _ in range(10):
        # rebuild the inputs for every run; only the transform call itself is timed
        data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
        if seg is not None:
            data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
        st = time()
        out = sp(**data_dict)
        times.append(time() - st)
    # median runtime of one transform call in seconds
    print(np.median(times))
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
#################
|
|
473
|
+
# with this part we can qualitatively test that the correct axes are being augmented. Just set one of the probs to 1 and off you go
|
|
474
|
+
#################
|
|
475
|
+
|
|
476
|
+
# def eldef_scale(image, dim, patch_size):
|
|
477
|
+
# return 0.1
|
|
478
|
+
#
|
|
479
|
+
# def eldef_magnitude(image, dim, patch_size, deformation_scale):
|
|
480
|
+
# return 10 if dim == 2 else 0
|
|
481
|
+
#
|
|
482
|
+
# def rot(image, dim):
|
|
483
|
+
# return 45/360 * 2 * np.pi if dim == 0 else 0
|
|
484
|
+
#
|
|
485
|
+
# def scaling(image, dim):
|
|
486
|
+
# return 0.5 if dim == 0 else 1
|
|
487
|
+
#
|
|
488
|
+
# # lines
|
|
489
|
+
# patch = torch.zeros((1, 64, 60, 68))
|
|
490
|
+
# patch[:, :, 10, 30] = 1
|
|
491
|
+
# patch[:, 50, :, 30] = 1
|
|
492
|
+
# patch[:, 40, 20, :] = 1
|
|
493
|
+
#
|
|
494
|
+
# # patch_block
|
|
495
|
+
# patch_block = torch.zeros((1, 64, 60, 68))
|
|
496
|
+
# patch_block[:, 22:42, 20:40, 24:44] = 1
|
|
497
|
+
#
|
|
498
|
+
# patch_line = torch.zeros((1, 64, 60, 128))
|
|
499
|
+
# patch_line[:, 22:24, 30:32, 10:-10] = 1
|
|
500
|
+
# use = patch_line
|
|
501
|
+
#
|
|
502
|
+
# sp = SpatialTransform(
|
|
503
|
+
# patch_size=patch.shape[1:],
|
|
504
|
+
# patch_center_dist_from_border=0,
|
|
505
|
+
# random_crop=False,
|
|
506
|
+
# p_elastic_deform=0,
|
|
507
|
+
# p_rotation=1,
|
|
508
|
+
# p_scaling=0,
|
|
509
|
+
# elastic_deform_scale=eldef_scale,
|
|
510
|
+
# elastic_deform_magnitude=eldef_magnitude,
|
|
511
|
+
# p_synchronize_def_scale_across_axes=0,
|
|
512
|
+
# rotation=rot,
|
|
513
|
+
# scaling=scaling,
|
|
514
|
+
# p_synchronize_scaling_across_axes=0,
|
|
515
|
+
# bg_style_seg_sampling=False,
|
|
516
|
+
# mode_seg='bilinear'
|
|
517
|
+
# )
|
|
518
|
+
#
|
|
519
|
+
#
|
|
520
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(use[0].numpy()), 'orig.nii.gz')
|
|
521
|
+
#
|
|
522
|
+
# params = sp.get_parameters(image=use)
|
|
523
|
+
# transformed = sp._apply_to_image(use, **params)
|
|
524
|
+
#
|
|
525
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
|
|
526
|
+
|
|
527
|
+
# p = torch.zeros((1, 1, 8, 16, 32))
|
|
528
|
+
# p[:, :, 2:6, 10:16, 10:24] = 1
|
|
529
|
+
# grid = _create_identity_grid(p.shape[2:])
|
|
530
|
+
# grid[:, :, :, 0] *= 0.5
|
|
531
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
532
|
+
# torch.all(out == p)
|
|
533
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
|
|
534
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
|
|
535
|
+
|
|
536
|
+
#################
|
|
537
|
+
# with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
|
|
538
|
+
#################
|
|
539
|
+
|
|
540
|
+
# sp = SpatialTransform(
|
|
541
|
+
# patch_size=(48, 52, 54),
|
|
542
|
+
# patch_center_dist_from_border=0,
|
|
543
|
+
# random_crop=True,
|
|
544
|
+
# p_elastic_deform=0,
|
|
545
|
+
# p_rotation=1,
|
|
546
|
+
# p_scaling=0,
|
|
547
|
+
# rotation=0
|
|
548
|
+
# )
|
|
549
|
+
# sp2 = SpatialTransform(
|
|
550
|
+
# patch_size=(48, 52, 54),
|
|
551
|
+
# patch_center_dist_from_border=0,
|
|
552
|
+
# random_crop=True,
|
|
553
|
+
# p_elastic_deform=0,
|
|
554
|
+
# p_rotation=0,
|
|
555
|
+
# p_scaling=0,
|
|
556
|
+
# )
|
|
557
|
+
#
|
|
558
|
+
# patch = torch.zeros((1, 64, 60, 68))
|
|
559
|
+
# patch[:, :, 10, 30] = 1
|
|
560
|
+
# patch[:, 50, :, 30] = 1
|
|
561
|
+
# patch[:, 40, 20, :] = 1
|
|
562
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
|
|
563
|
+
#
|
|
564
|
+
# center_coords = [50, 10, 16]
|
|
565
|
+
# params = sp.get_parameters(image=patch)
|
|
566
|
+
# params['center_location_in_pixels'] = center_coords
|
|
567
|
+
# params2 = sp2.get_parameters(image=patch)
|
|
568
|
+
# params2['center_location_in_pixels'] = center_coords
|
|
569
|
+
# transformed = sp._apply_to_image(patch, **params)
|
|
570
|
+
# transformed2 = sp._apply_to_image(patch, **params2)
|
|
571
|
+
#
|
|
572
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
|
|
573
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
####################
|
|
578
|
+
# This is exploratory code to check how to retrieve coordinates. I used it to verify that grid_sample does in fact
|
|
579
|
+
# use coordinates in reversed dimension order (zyx and not xyz)
|
|
580
|
+
####################
|
|
581
|
+
# # create a dummy input which has a unique shape in each axis
|
|
582
|
+
# p = torch.zeros((1, 1, 8, 16, 32))
|
|
583
|
+
# # set one pixel to 1
|
|
584
|
+
# p[:, :, 4, 0, 31] = 1
|
|
585
|
+
# # now create an identity grid. I have verified that this grid yields the same image as the input when used in grid_sample. So the grid is correct
|
|
586
|
+
# grid = _create_identity_grid((8, 16, 32)).contiguous() # grid is shape torch.Size([8, 16, 32, 3])
|
|
587
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
588
|
+
# assert torch.all(out == p) # this passes
|
|
589
|
+
# # reduce the grid to the location we are interested in. That are the coordinates where we placed the 1. The 4:5 etc is only so that we keep the number of dimensions
|
|
590
|
+
# grid = grid[4:5, 0:1, 31:32]
|
|
591
|
+
# # What coordinate would we expect? Note that grid is [-1, 1]
|
|
592
|
+
# # For the first dimension, coordinate 4 out of shape 8 is approximately in the middle, so about 0
|
|
593
|
+
# # For the second dimension, coordinate 0 out of shape 16 is very low, so we expect -1 ish (remember there is aligned corners and shit)
|
|
594
|
+
# # For the third dimension, coordinate 31 out of shape 32 is very high, so we expect 1 ish (remember there is aligned corners and shit)
|
|
595
|
+
# # So we expect [0, -1, 1]
|
|
596
|
+
# # What do we get?
|
|
597
|
+
# print(grid)
|
|
598
|
+
# # > tensor([[[[ 0.9688, -0.9375, 0.1250]]]])
|
|
599
|
+
# # not what we expect
|
|
600
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
601
|
+
# assert out.item() == 1
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import Set
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TransposeAxesTransform(BasicTransform):
    """
    Randomly reorder a subset of the spatial axes of a sample.

    The same axis order is used for image, segmentation and regression target. All axes that take part in the
    permutation must have the same length.

    Attributes:
        allowed_axes (Set[int]): Spatial axes that may be permuted, counted without the leading axis (e.g. {1, 2}
            for y and z in an image of shape (c, x, y, z)).
    """

    def __init__(self, allowed_axes: Set[int]):
        """
        Store the spatial axes that may take part in the permutation.

        Args:
            allowed_axes (Set[int]): Spatial axis indices for permutation.
        """
        super().__init__()
        self.allowed_axes = allowed_axes

    def get_parameters(self, **data_dict) -> dict:
        """
        Draw a random order for the allowed axes.

        Args:
            data_dict (dict): Must contain the `image` tensor; only its shape is read.

        Returns:
            dict: 'axis_order', a list with one int per axis of `image` (identity if fewer than two axes are
            allowed).

        Raises:
            ValueError: If the allowed axes do not all have the same length.
        """
        image_shape = data_dict['image'].shape
        num_axes = len(image_shape)
        # +1 skips the leading (non-spatial) axis
        shape_of_allowed = [image_shape[1 + i] for i in self.allowed_axes]

        # nothing to permute with fewer than two candidate axes
        if len(shape_of_allowed) < 2:
            return {'axis_order': list(range(num_axes))}

        if any(i != shape_of_allowed[0] for i in shape_of_allowed[1:]):
            raise ValueError(f"Axis shapes are not identical: {shape_of_allowed}. Cannot permute.\n"
                             f"Image shape: {data_dict['image'].shape}. Allowed axes: {self.allowed_axes}")

        shuffled_axes = [i + 1 for i in self.allowed_axes]
        np.random.shuffle(shuffled_axes)

        # start from the identity and overwrite the slots of the allowed axes with their shuffled order
        axis_order = np.arange(num_axes)
        is_allowed_slot = np.isin(axis_order, shuffled_axes)
        axis_order[is_allowed_slot] = shuffled_axes
        return {'axis_order': [int(i) for i in axis_order]}

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        """Permute the segmentation with the sampled axis order."""
        permuted = segmentation.permute(params['axis_order'])
        return permuted.contiguous()

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Permute the image with the sampled axis order."""
        permuted = img.permute(params['axis_order'])
        return permuted.contiguous()

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        """Permute the regression target with the sampled axis order."""
        permuted = regression_target.permute(params['axis_order'])
        return permuted.contiguous()

    def _apply_to_bbox(self, bbox, **params):
        """Not supported by this transform."""
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        """Not supported by this transform."""
        raise NotImplementedError
|
|
64
|
+
|
|
65
|
+
if __name__ == '__main__':
    # smoke test: permute the two equally sized trailing axes of a small random sample
    t = TransposeAxesTransform((1, 2))
    demo_sample = {
        'image': torch.rand((2, 31, 32, 32)),
        'segmentation': torch.ones((1, 31, 32, 32)),
    }
    ret = t(**demo_sample)
|
|
File without changes
|