batchgeneratorsv2 0.2.1__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/PKG-INFO +2 -2
- batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/spatial/low_resolution.py +6 -2
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/spatial/spatial.py +46 -15
- batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2.egg-info/PKG-INFO +2 -2
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2.egg-info/SOURCES.txt +2 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/pyproject.toml +1 -1
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/LICENSE +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/benchmarks/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/dataloading/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/helpers/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/base/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/base/basic_transform.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/intensity/brightness.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/intensity/contrast.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/intensity/gamma.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/compose.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/pseudo2d.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/random.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/remove_label.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/seg_to_regions.py +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2.egg-info/requires.txt +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/readme.md +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/setup.cfg +0 -0
- {batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/setup.py +0 -0
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class InvertImageTransform(ImageOnlyTransform):
    """
    Randomly mirror image intensities around their per-channel mean.

    Every selected channel x is replaced by -(x - mean(x)) + mean(x): the intensity ordering is
    reversed while the channel mean stays the same.

    Args:
        p_invert_image: probability that a sample is inverted at all.
        p_synchronize_channels: if a sample is inverted, probability that all channels are inverted together.
        p_per_channel: otherwise, independent per-channel probability of being inverted.
    """

    def __init__(self, p_invert_image: float, p_synchronize_channels: float = 1, p_per_channel: float = 1):
        super().__init__()
        self.p_invert_image = p_invert_image
        self.p_synchronize_channels = p_synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        """
        Decide whether to invert this sample and which channels (indices into the first image axis) to invert.

        Returns:
            dict with 'apply_to_channel' (empty list, or a 1D integer tensor of channel indices) and
            'apply' (whether the inversion happens at all).
        """
        image_shape = data_dict['image'].shape
        apply = np.random.uniform() < self.p_invert_image
        if not apply:
            selected_channels = []
        elif np.random.uniform() < self.p_synchronize_channels:
            # synchronized: every channel gets inverted
            selected_channels = torch.arange(0, image_shape[0])
        else:
            # independent coin flip per channel
            coin_flips = torch.rand(image_shape[0]) < self.p_per_channel
            selected_channels = torch.where(coin_flips)[0]
        return {
            'apply_to_channel': selected_channels,
            'apply': apply,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Invert the selected channels of ``img`` around their own mean; ``img`` is written to in place."""
        if params['apply']:
            for channel in params['apply_to_channel']:
                # the mean is taken before the channel is modified
                channel_mean = img[channel].mean()
                img[channel] = -(img[channel] - channel_mean) + channel_mean
        return img
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
if __name__ == '__main__':
    # Visual demo: channel 0 is -1 in its first 10 slices and channel 1 in its first 5 slices, +1 elsewhere.
    # An inverted channel shows the pattern mirrored around its mean.
    mbt = InvertImageTransform(0.5, 0.5, 0.5)
    from batchviewer import view_batch

    for _ in range(100):
        data_dict = {'image': torch.ones((2, 20, 192, 64))}
        data_dict['image'][0, :10] = -1
        data_dict['image'][1, :5] = -1
        ret = mbt(**data_dict)
        print(ret['image'][0, 0, 0, 0], ret['image'][1, 0, 0, 0])
        # show the result that was just printed. Calling mbt again here would draw new random parameters, and
        # since _apply_to_image modifies its input tensor in place, it would presumably re-invert the
        # already transformed data.
        view_batch(ret['image'])
|
|
@@ -8,9 +8,13 @@ from torch.nn.functional import interpolate
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class SimulateLowResolutionTransform(ImageOnlyTransform):
|
|
11
|
-
def __init__(self,
|
|
11
|
+
def __init__(self,
|
|
12
|
+
scale: RandomScalar,
|
|
13
|
+
synchronize_channels: bool,
|
|
14
|
+
synchronize_axes: bool,
|
|
12
15
|
ignore_axes: Tuple[int, ...],
|
|
13
|
-
allowed_channels: Tuple[int, ...] = None,
|
|
16
|
+
allowed_channels: Tuple[int, ...] = None,
|
|
17
|
+
p_per_channel: float = 1):
|
|
14
18
|
super().__init__()
|
|
15
19
|
self.scale = scale
|
|
16
20
|
self.synchronize_channels = synchronize_channels
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/spatial/spatial.py
RENAMED
|
@@ -7,7 +7,7 @@ import SimpleITK
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pandas as pd
|
|
9
9
|
import torch
|
|
10
|
-
from scipy.ndimage import fourier_gaussian
|
|
10
|
+
from scipy.ndimage import fourier_gaussian, gaussian_filter
|
|
11
11
|
from torch import Tensor
|
|
12
12
|
from torch.nn.functional import grid_sample
|
|
13
13
|
|
|
@@ -31,10 +31,17 @@ class SpatialTransform(BasicTransform):
|
|
|
31
31
|
scaling: RandomScalar = (0.7, 1.3),
|
|
32
32
|
p_synchronize_scaling_across_axes: float = 0,
|
|
33
33
|
bg_style_seg_sampling: bool = True,
|
|
34
|
-
mode_seg: str = 'bilinear'
|
|
34
|
+
mode_seg: str = 'bilinear',
|
|
35
|
+
border_mode_seg: str = "zeros",
|
|
36
|
+
center_deformation: bool = True,
|
|
37
|
+
padding_mode_image: str = "zeros"
|
|
35
38
|
):
|
|
36
39
|
"""
|
|
37
40
|
magnitude must be given in pixels!
|
|
41
|
+
deformation scale is given as a paercentage of the edge length
|
|
42
|
+
|
|
43
|
+
padding_mode_image: see torch grid_sample documentation. This currently applies to image and regression target
|
|
44
|
+
because both call self._apply_to_image. Can be "zeros", "reflection", "border"
|
|
38
45
|
"""
|
|
39
46
|
super().__init__()
|
|
40
47
|
self.patch_size = patch_size
|
|
@@ -44,7 +51,7 @@ class SpatialTransform(BasicTransform):
|
|
|
44
51
|
self.random_crop = random_crop
|
|
45
52
|
self.p_elastic_deform = p_elastic_deform
|
|
46
53
|
self.elastic_deform_scale = elastic_deform_scale # sigma for blurring offsets, in % of patch size. Larger values mean coarser deformation
|
|
47
|
-
self.elastic_deform_magnitude = elastic_deform_magnitude # determines the maximum displacement, measured in
|
|
54
|
+
self.elastic_deform_magnitude = elastic_deform_magnitude # determines the maximum displacement, measured in pixels!!
|
|
48
55
|
self.p_rotation = p_rotation
|
|
49
56
|
self.rotation = rotation
|
|
50
57
|
self.p_scaling = p_scaling
|
|
@@ -53,6 +60,9 @@ class SpatialTransform(BasicTransform):
|
|
|
53
60
|
self.p_synchronize_def_scale_across_axes = p_synchronize_def_scale_across_axes
|
|
54
61
|
self.bg_style_seg_sampling = bg_style_seg_sampling
|
|
55
62
|
self.mode_seg = mode_seg
|
|
63
|
+
self.border_mode_seg = border_mode_seg
|
|
64
|
+
self.center_deformation = center_deformation
|
|
65
|
+
self.padding_mode_image = padding_mode_image
|
|
56
66
|
|
|
57
67
|
def get_parameters(self, **data_dict) -> dict:
|
|
58
68
|
dim = data_dict['image'].ndim - 1
|
|
@@ -94,7 +104,7 @@ class SpatialTransform(BasicTransform):
|
|
|
94
104
|
else:
|
|
95
105
|
deformation_scales = [
|
|
96
106
|
sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
|
|
97
|
-
for i in range(
|
|
107
|
+
for i in range(dim)
|
|
98
108
|
]
|
|
99
109
|
|
|
100
110
|
# sigmas must be in pixels, as this will be applied to the deformation field
|
|
@@ -103,7 +113,7 @@ class SpatialTransform(BasicTransform):
|
|
|
103
113
|
magnitude = [
|
|
104
114
|
sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
|
|
105
115
|
dim=i, deformation_scale=deformation_scales[i])
|
|
106
|
-
for i in range(
|
|
116
|
+
for i in range(dim)]
|
|
107
117
|
# doing it like this for better memory layout for blurring
|
|
108
118
|
offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))
|
|
109
119
|
|
|
@@ -118,9 +128,15 @@ class SpatialTransform(BasicTransform):
|
|
|
118
128
|
tmp = fourier_gaussian(tmp, sigmas[d])
|
|
119
129
|
offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
|
|
120
130
|
|
|
131
|
+
# tmp = offsets[d].numpy().astype(np.float64)
|
|
132
|
+
# gaussian_filter(tmp, sigmas[d], 0, output=tmp)
|
|
133
|
+
# offsets[d] = torch.from_numpy(tmp).to(offsets.dtype)
|
|
134
|
+
# print(offsets.dtype)
|
|
135
|
+
|
|
121
136
|
mx = torch.max(torch.abs(offsets[d]))
|
|
122
137
|
offsets[d] /= (mx / np.clip(magnitude[d], a_min=1e-8, a_max=np.inf))
|
|
123
|
-
|
|
138
|
+
spatial_dims = tuple(list(range(1, dim + 1)))
|
|
139
|
+
offsets = torch.permute(offsets, (*spatial_dims, 0))
|
|
124
140
|
else:
|
|
125
141
|
offsets = None
|
|
126
142
|
|
|
@@ -147,8 +163,22 @@ class SpatialTransform(BasicTransform):
|
|
|
147
163
|
# No spatial transformation is being done. Round grid_center and crop without having to interpolate.
|
|
148
164
|
# This saves compute.
|
|
149
165
|
# cropping requires the center to be given as integer coordinates
|
|
150
|
-
|
|
151
|
-
|
|
166
|
+
|
|
167
|
+
# torch is inconsistent. AAAAaaah
|
|
168
|
+
if self.padding_mode_image == 'reflection':
|
|
169
|
+
pad_mode = 'reflect'
|
|
170
|
+
pad_kwargs = {}
|
|
171
|
+
elif self.padding_mode_image == 'zeros':
|
|
172
|
+
pad_mode = 'constant'
|
|
173
|
+
pad_kwargs = {'value': 0}
|
|
174
|
+
elif self.padding_mode_image == 'border':
|
|
175
|
+
pad_mode = 'replicate'
|
|
176
|
+
pad_kwargs = {}
|
|
177
|
+
else:
|
|
178
|
+
raise RuntimeError('Unknown pad mode')
|
|
179
|
+
|
|
180
|
+
img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']], self.patch_size, pad_mode=pad_mode,
|
|
181
|
+
pad_kwargs=pad_kwargs)
|
|
152
182
|
return img
|
|
153
183
|
else:
|
|
154
184
|
grid = _create_centered_identity_grid2(self.patch_size)
|
|
@@ -161,15 +191,16 @@ class SpatialTransform(BasicTransform):
|
|
|
161
191
|
|
|
162
192
|
# we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
|
|
163
193
|
# only do this if we elastic deform
|
|
164
|
-
if params['elastic_offsets'] is not None:
|
|
194
|
+
if self.center_deformation and params['elastic_offsets'] is not None:
|
|
165
195
|
mn = grid.mean(dim=list(range(img.ndim - 1)))
|
|
166
196
|
else:
|
|
167
197
|
mn = 0
|
|
168
198
|
|
|
169
199
|
new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])])
|
|
170
200
|
grid += (new_center - mn)
|
|
201
|
+
# print(f'grid sample with pad mode {self.padding_mode_image}')
|
|
171
202
|
return grid_sample(img[None], _convert_my_grid_to_grid_sample_grid(grid, img.shape[1:])[None],
|
|
172
|
-
mode='bilinear', padding_mode=
|
|
203
|
+
mode='bilinear', padding_mode=self.padding_mode_image, align_corners=False)[0]
|
|
173
204
|
|
|
174
205
|
def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
|
|
175
206
|
segmentation = segmentation.contiguous()
|
|
@@ -193,7 +224,7 @@ class SpatialTransform(BasicTransform):
|
|
|
193
224
|
grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
|
|
194
225
|
|
|
195
226
|
# we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
|
|
196
|
-
if params['elastic_offsets'] is not None:
|
|
227
|
+
if self.center_deformation and params['elastic_offsets'] is not None:
|
|
197
228
|
mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
|
|
198
229
|
else:
|
|
199
230
|
mn = 0
|
|
@@ -208,7 +239,7 @@ class SpatialTransform(BasicTransform):
|
|
|
208
239
|
segmentation[None].float(),
|
|
209
240
|
grid[None],
|
|
210
241
|
mode=self.mode_seg,
|
|
211
|
-
padding_mode=
|
|
242
|
+
padding_mode=self.border_mode_seg,
|
|
212
243
|
align_corners=False
|
|
213
244
|
)[0].to(segmentation.dtype)
|
|
214
245
|
else:
|
|
@@ -222,7 +253,7 @@ class SpatialTransform(BasicTransform):
|
|
|
222
253
|
((segmentation[c] == labels[1]).float())[None, None],
|
|
223
254
|
grid[None],
|
|
224
255
|
mode=self.mode_seg,
|
|
225
|
-
padding_mode=
|
|
256
|
+
padding_mode=self.border_mode_seg,
|
|
226
257
|
align_corners=False
|
|
227
258
|
)[0][0] >= 0.5
|
|
228
259
|
result_seg[c][out] = labels[1]
|
|
@@ -234,7 +265,7 @@ class SpatialTransform(BasicTransform):
|
|
|
234
265
|
((segmentation[c] == u).float())[None, None],
|
|
235
266
|
grid[None],
|
|
236
267
|
mode=self.mode_seg,
|
|
237
|
-
padding_mode=
|
|
268
|
+
padding_mode=self.border_mode_seg,
|
|
238
269
|
align_corners=False
|
|
239
270
|
)[0][0] >= 0.5] = u
|
|
240
271
|
else:
|
|
@@ -246,7 +277,7 @@ class SpatialTransform(BasicTransform):
|
|
|
246
277
|
done_mask = torch.zeros(*self.patch_size, dtype=torch.bool)
|
|
247
278
|
for i, u in enumerate(labels):
|
|
248
279
|
tmp[i] = grid_sample(((segmentation[c] == u).float() * scale_factor)[None, None], grid[None],
|
|
249
|
-
mode=self.mode_seg, padding_mode=
|
|
280
|
+
mode=self.mode_seg, padding_mode=self.border_mode_seg, align_corners=False)[0][0]
|
|
250
281
|
mask = tmp[i] > (0.7 * scale_factor)
|
|
251
282
|
result_seg[c][mask] = u
|
|
252
283
|
done_mask = done_mask | mask
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import Set
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TransposeAxesTransform(BasicTransform):
    """
    Randomly permute a group of spatial axes of an image and its related data.

    All axes in the group must have the same size, so the permuted data keeps its shape.

    Attributes:
        allowed_axes (Set[int]): Spatial axes (counted without the leading channel axis) that may be
            permuted among each other, e.g. {1, 2} for the y and z axes in an image of shape (c, x, y, z).
    """

    def __init__(self, allowed_axes: Set[int]):
        """
        Initialize the transform with allowed spatial axes for permutation.

        Args:
            allowed_axes (Set[int]): Set of spatial axis indices for permutation. Other iterables of ints
                (e.g. a tuple) work as well; repeated entries are ignored.
        """
        super().__init__()
        self.allowed_axes = allowed_axes

    def get_parameters(self, **data_dict) -> dict:
        """
        Generate a random axis permutation order.

        Args:
            data_dict (dict): Dictionary containing `image` tensor data; its first axis is the channel axis.

        Returns:
            dict: 'axis_order', a permutation of range(image.ndim) as a list of ints that only moves the
            allowed axes (axis 0, the channel axis, always stays in place). It is the identity when fewer
            than two distinct axes are allowed.

        Raises:
            ValueError: if an allowed axis is not a valid spatial axis of the image, or if the allowed
                axes do not all have the same size.
        """
        shape = data_dict['image'].shape
        num_spatial = len(shape) - 1
        # Drop duplicates while keeping the caller's order. A repeated axis would otherwise make the
        # boolean-mask assignment below fail with a length mismatch.
        spatial_axes = list(dict.fromkeys(self.allowed_axes))
        invalid = [i for i in spatial_axes if not 0 <= i < num_spatial]
        if invalid:
            # Negative values would otherwise index the channel axis (shape[1 + -1] == shape[0]) or fail obscurely.
            raise ValueError(f"Allowed axes {invalid} are not valid spatial axes for image shape {shape}. "
                             f"Expected values in [0, {num_spatial}).")

        shape_of_allowed = [shape[1 + i] for i in spatial_axes]
        if len(shape_of_allowed) < 2:
            return {'axis_order': list(range(len(shape)))}
        if not all(i == shape_of_allowed[0] for i in shape_of_allowed[1:]):
            raise ValueError(f"Axis shapes are not identical: {shape_of_allowed}. Cannot permute.\n"
                             f"Image shape: {data_dict['image'].shape}. Allowed axes: {self.allowed_axes}")

        # +1 skips the channel axis
        axes = [i + 1 for i in spatial_axes]
        np.random.shuffle(axes)
        axis_order = np.arange(len(shape))
        # the mask selects the allowed positions in ascending order; fill them with the shuffled axes
        axis_order[np.isin(axis_order, axes)] = axes
        return {'axis_order': [int(i) for i in axis_order]}

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        """Permute the segmentation with the sampled axis order (returned contiguous)."""
        return segmentation.permute(params['axis_order']).contiguous()

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Permute the image with the sampled axis order (returned contiguous)."""
        return img.permute(params['axis_order']).contiguous()

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        """Permute the regression target with the sampled axis order (returned contiguous)."""
        return regression_target.permute(params['axis_order']).contiguous()

    def _apply_to_bbox(self, bbox, **params):
        # bounding boxes are not supported by this transform
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        # keypoints are not supported by this transform
        raise NotImplementedError
|
|
64
|
+
|
|
65
|
+
if __name__ == '__main__':
    # Smoke test: randomly permute spatial axes 1 and 2 (both of size 32) of a random image and a dummy segmentation.
    t = TransposeAxesTransform((1, 2))
    sample = {
        'image': torch.rand((2, 31, 32, 32)),
        'segmentation': torch.ones((1, 31, 32, 32)),
    }
    ret = t(**sample)
|
|
@@ -24,6 +24,7 @@ batchgeneratorsv2/transforms/intensity/brightness.py
|
|
|
24
24
|
batchgeneratorsv2/transforms/intensity/contrast.py
|
|
25
25
|
batchgeneratorsv2/transforms/intensity/gamma.py
|
|
26
26
|
batchgeneratorsv2/transforms/intensity/gaussian_noise.py
|
|
27
|
+
batchgeneratorsv2/transforms/intensity/inversion.py
|
|
27
28
|
batchgeneratorsv2/transforms/nnunet/__init__.py
|
|
28
29
|
batchgeneratorsv2/transforms/nnunet/random_binary_operator.py
|
|
29
30
|
batchgeneratorsv2/transforms/nnunet/remove_connected_components.py
|
|
@@ -34,6 +35,7 @@ batchgeneratorsv2/transforms/spatial/__init__.py
|
|
|
34
35
|
batchgeneratorsv2/transforms/spatial/low_resolution.py
|
|
35
36
|
batchgeneratorsv2/transforms/spatial/mirroring.py
|
|
36
37
|
batchgeneratorsv2/transforms/spatial/spatial.py
|
|
38
|
+
batchgeneratorsv2/transforms/spatial/transpose.py
|
|
37
39
|
batchgeneratorsv2/transforms/utils/__init__.py
|
|
38
40
|
batchgeneratorsv2/transforms/utils/compose.py
|
|
39
41
|
batchgeneratorsv2/transforms/utils/cropping.py
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/benchmarks/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/benchmarks/unique_values.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/dataloading/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/helpers/scalar_type.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/__init__.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/base/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/intensity/gamma.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/nnunet/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/noise/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/spatial/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/__init__.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/compose.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/cropping.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/pseudo2d.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2/transforms/utils/random.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.2.1 → batchgeneratorsv2-0.2.3}/batchgeneratorsv2.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|