batchgeneratorsv2 0.1.1__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/PKG-INFO +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/base/basic_transform.py +5 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/contrast.py +3 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/spatial.py +267 -90
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/random.py +5 -1
- batchgeneratorsv2-0.2.1/batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/PKG-INFO +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/pyproject.toml +1 -1
- batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/seg_to_regions.py +0 -23
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/LICENSE +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/dataloading/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/helpers/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/base/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/brightness.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/gamma.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/low_resolution.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/compose.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/pseudo2d.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/remove_label.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/SOURCES.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/requires.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/readme.md +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/setup.cfg +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/setup.py +0 -0
|
@@ -135,4 +135,4 @@ if __name__ == '__main__':
|
|
|
135
135
|
compute_times[i].append(time() - st)
|
|
136
136
|
|
|
137
137
|
for t, ct in zip(transforms, compute_times):
|
|
138
|
-
print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.
|
|
138
|
+
print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.mean(ct))
|
|
@@ -53,6 +53,11 @@ class BasicTransform(abc.ABC):
|
|
|
53
53
|
def get_parameters(self, **data_dict) -> dict:
|
|
54
54
|
return {}
|
|
55
55
|
|
|
56
|
+
def __repr__(self):
|
|
57
|
+
ret_str = str(type(self).__name__) + "( " + ", ".join(
|
|
58
|
+
[key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )"
|
|
59
|
+
return ret_str
|
|
60
|
+
|
|
56
61
|
|
|
57
62
|
class ImageOnlyTransform(BasicTransform):
|
|
58
63
|
def apply(self, data_dict: dict, **params) -> dict:
|
|
@@ -22,6 +22,9 @@ class BGContrast():
|
|
|
22
22
|
def __call__(self, *args, **kwargs):
|
|
23
23
|
return self.sample_contrast(*args, **kwargs)
|
|
24
24
|
|
|
25
|
+
def __repr__(self):
|
|
26
|
+
return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
|
|
27
|
+
|
|
25
28
|
class ContrastTransform(ImageOnlyTransform):
|
|
26
29
|
def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
|
|
27
30
|
super().__init__()
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/spatial.py
RENAMED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from copy import deepcopy
|
|
2
2
|
from typing import Tuple, List, Union
|
|
3
3
|
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
import SimpleITK
|
|
4
7
|
import numpy as np
|
|
5
8
|
import pandas as pd
|
|
6
9
|
import torch
|
|
@@ -14,17 +17,25 @@ from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
|
|
|
14
17
|
|
|
15
18
|
|
|
16
19
|
class SpatialTransform(BasicTransform):
|
|
17
|
-
def __init__(self,
|
|
20
|
+
def __init__(self,
|
|
21
|
+
patch_size: Tuple[int, ...],
|
|
18
22
|
patch_center_dist_from_border: Union[int, List[int], Tuple[int, ...]],
|
|
19
23
|
random_crop: bool,
|
|
20
|
-
p_elastic_deform: float = 0,
|
|
24
|
+
p_elastic_deform: float = 0,
|
|
25
|
+
elastic_deform_scale: RandomScalar = (0, 0.2),
|
|
21
26
|
elastic_deform_magnitude: RandomScalar = (0, 0.2),
|
|
22
|
-
p_synchronize_def_scale_across_axes: float =
|
|
23
|
-
p_rotation: float = 0,
|
|
24
|
-
|
|
27
|
+
p_synchronize_def_scale_across_axes: float = 0,
|
|
28
|
+
p_rotation: float = 0,
|
|
29
|
+
rotation: RandomScalar = (0, 2 * np.pi),
|
|
30
|
+
p_scaling: float = 0,
|
|
31
|
+
scaling: RandomScalar = (0.7, 1.3),
|
|
32
|
+
p_synchronize_scaling_across_axes: float = 0,
|
|
25
33
|
bg_style_seg_sampling: bool = True,
|
|
26
34
|
mode_seg: str = 'bilinear'
|
|
27
35
|
):
|
|
36
|
+
"""
|
|
37
|
+
magnitude must be given in pixels!
|
|
38
|
+
"""
|
|
28
39
|
super().__init__()
|
|
29
40
|
self.patch_size = patch_size
|
|
30
41
|
if not isinstance(patch_center_dist_from_border, (tuple, list)):
|
|
@@ -51,21 +62,23 @@ class SpatialTransform(BasicTransform):
|
|
|
51
62
|
do_deform = np.random.uniform() < self.p_elastic_deform
|
|
52
63
|
|
|
53
64
|
if do_rotation:
|
|
54
|
-
angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(
|
|
65
|
+
angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(0, 3)]
|
|
55
66
|
else:
|
|
56
67
|
angles = [0] * dim
|
|
57
68
|
if do_scale:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
69
|
+
if np.random.uniform() <= self.p_synchronize_scaling_across_axes:
|
|
70
|
+
scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=None)] * dim
|
|
71
|
+
else:
|
|
72
|
+
scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(0, 3)]
|
|
61
73
|
else:
|
|
62
74
|
scales = [1] * dim
|
|
75
|
+
|
|
63
76
|
# affine matrix
|
|
64
77
|
if do_scale or do_rotation:
|
|
65
78
|
if dim == 3:
|
|
66
79
|
affine = create_affine_matrix_3d(angles, scales)
|
|
67
80
|
elif dim == 2:
|
|
68
|
-
affine = create_affine_matrix_2d(angles[
|
|
81
|
+
affine = create_affine_matrix_2d(angles[-1], scales)
|
|
69
82
|
else:
|
|
70
83
|
raise RuntimeError(f'Unsupported dimension: {dim}')
|
|
71
84
|
else:
|
|
@@ -74,19 +87,23 @@ class SpatialTransform(BasicTransform):
|
|
|
74
87
|
# elastic deformation. We need to create the displacement field here
|
|
75
88
|
# we use the method from augment_spatial_2 in batchgenerators
|
|
76
89
|
if do_deform:
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
90
|
+
if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:
|
|
91
|
+
deformation_scales = [
|
|
92
|
+
sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)
|
|
93
|
+
] * dim
|
|
94
|
+
else:
|
|
95
|
+
deformation_scales = [
|
|
96
|
+
sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
|
|
97
|
+
for i in range(0, 3)
|
|
98
|
+
]
|
|
99
|
+
|
|
83
100
|
# sigmas must be in pixels, as this will be applied to the deformation field
|
|
84
101
|
sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]
|
|
85
|
-
|
|
86
|
-
# we need to correct magnitude by grid_scale to account for the fact that the grid will be wrt to the image size but the magnitude should be wrt the patch size. oof.
|
|
102
|
+
|
|
87
103
|
magnitude = [
|
|
88
104
|
sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
|
|
89
|
-
dim=i, deformation_scale=deformation_scales[i])
|
|
105
|
+
dim=i, deformation_scale=deformation_scales[i])
|
|
106
|
+
for i in range(0, 3)]
|
|
90
107
|
# doing it like this for better memory layout for blurring
|
|
91
108
|
offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))
|
|
92
109
|
|
|
@@ -98,7 +115,7 @@ class SpatialTransform(BasicTransform):
|
|
|
98
115
|
|
|
99
116
|
# fft numpy, this is faster o.O
|
|
100
117
|
tmp = np.fft.fftn(offsets[d].numpy())
|
|
101
|
-
tmp = fourier_gaussian(tmp, sigmas)
|
|
118
|
+
tmp = fourier_gaussian(tmp, sigmas[d])
|
|
102
119
|
offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
|
|
103
120
|
|
|
104
121
|
mx = torch.max(torch.abs(offsets[d]))
|
|
@@ -106,13 +123,13 @@ class SpatialTransform(BasicTransform):
|
|
|
106
123
|
offsets = torch.permute(offsets, (1, 2, 3, 0))
|
|
107
124
|
else:
|
|
108
125
|
offsets = None
|
|
109
|
-
|
|
126
|
+
|
|
110
127
|
shape = data_dict['image'].shape[1:]
|
|
111
128
|
if not self.random_crop:
|
|
112
129
|
center_location_in_pixels = [i / 2 for i in shape]
|
|
113
130
|
else:
|
|
114
131
|
center_location_in_pixels = []
|
|
115
|
-
for d in range(
|
|
132
|
+
for d in range(0, 3):
|
|
116
133
|
mn = self.patch_center_dist_from_border[d]
|
|
117
134
|
mx = shape[d] - self.patch_center_dist_from_border[d]
|
|
118
135
|
if mx < mn:
|
|
@@ -130,15 +147,11 @@ class SpatialTransform(BasicTransform):
|
|
|
130
147
|
# No spatial transformation is being done. Round grid_center and crop without having to interpolate.
|
|
131
148
|
# This saves compute.
|
|
132
149
|
# cropping requires the center to be given as integer coordinates
|
|
133
|
-
img = crop_tensor(img, [
|
|
150
|
+
img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']], self.patch_size, pad_mode='constant',
|
|
134
151
|
pad_kwargs={'value': 0})
|
|
135
152
|
return img
|
|
136
153
|
else:
|
|
137
|
-
grid =
|
|
138
|
-
|
|
139
|
-
# the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
|
|
140
|
-
grid_scale = torch.Tensor([i / j for i, j in zip(img.shape[1:], self.patch_size)])
|
|
141
|
-
grid /= grid_scale
|
|
154
|
+
grid = _create_centered_identity_grid2(self.patch_size)
|
|
142
155
|
|
|
143
156
|
# we deform first, then rotate
|
|
144
157
|
if params['elastic_offsets'] is not None:
|
|
@@ -147,11 +160,16 @@ class SpatialTransform(BasicTransform):
|
|
|
147
160
|
grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
|
|
148
161
|
|
|
149
162
|
# we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
163
|
+
# only do this if we elastic deform
|
|
164
|
+
if params['elastic_offsets'] is not None:
|
|
165
|
+
mn = grid.mean(dim=list(range(img.ndim - 1)))
|
|
166
|
+
else:
|
|
167
|
+
mn = 0
|
|
168
|
+
|
|
169
|
+
new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], img.shape[1:])])
|
|
170
|
+
grid += (new_center - mn)
|
|
171
|
+
return grid_sample(img[None], _convert_my_grid_to_grid_sample_grid(grid, img.shape[1:])[None],
|
|
172
|
+
mode='bilinear', padding_mode="zeros", align_corners=False)[0]
|
|
155
173
|
|
|
156
174
|
def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
|
|
157
175
|
segmentation = segmentation.contiguous()
|
|
@@ -159,15 +177,14 @@ class SpatialTransform(BasicTransform):
|
|
|
159
177
|
# No spatial transformation is being done. Round grid_center and crop without having to interpolate.
|
|
160
178
|
# This saves compute.
|
|
161
179
|
# cropping requires the center to be given as integer coordinates
|
|
162
|
-
segmentation = crop_tensor(segmentation,
|
|
163
|
-
|
|
180
|
+
segmentation = crop_tensor(segmentation,
|
|
181
|
+
[math.floor(i) for i in params['center_location_in_pixels']],
|
|
182
|
+
self.patch_size,
|
|
183
|
+
pad_mode='constant',
|
|
184
|
+
pad_kwargs={'value': 0})
|
|
164
185
|
return segmentation
|
|
165
186
|
else:
|
|
166
|
-
grid =
|
|
167
|
-
|
|
168
|
-
# the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
|
|
169
|
-
grid_scale = torch.Tensor([i / j for i, j in zip(segmentation.shape[1:], self.patch_size)])
|
|
170
|
-
grid /= grid_scale
|
|
187
|
+
grid = _create_centered_identity_grid2(self.patch_size)
|
|
171
188
|
|
|
172
189
|
# we deform first, then rotate
|
|
173
190
|
if params['elastic_offsets'] is not None:
|
|
@@ -176,10 +193,15 @@ class SpatialTransform(BasicTransform):
|
|
|
176
193
|
grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
|
|
177
194
|
|
|
178
195
|
# we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
196
|
+
if params['elastic_offsets'] is not None:
|
|
197
|
+
mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
|
|
198
|
+
else:
|
|
199
|
+
mn = 0
|
|
200
|
+
|
|
201
|
+
new_center = torch.Tensor([c - s / 2 for c, s in zip(params['center_location_in_pixels'], segmentation.shape[1:])])
|
|
202
|
+
|
|
203
|
+
grid += (new_center - mn)
|
|
204
|
+
grid = _convert_my_grid_to_grid_sample_grid(grid, segmentation.shape[1:])
|
|
183
205
|
|
|
184
206
|
if self.mode_seg == 'nearest':
|
|
185
207
|
result_seg = grid_sample(
|
|
@@ -279,15 +301,38 @@ def create_affine_matrix_2d(rotation_angle, scaling_factors):
|
|
|
279
301
|
return RS
|
|
280
302
|
|
|
281
303
|
|
|
282
|
-
def _create_identity_grid(size: List[int]) -> Tensor:
|
|
283
|
-
|
|
304
|
+
# def _create_identity_grid(size: List[int]) -> Tensor:
|
|
305
|
+
# space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size[::-1]]
|
|
306
|
+
# grid = torch.meshgrid(space, indexing="ij")
|
|
307
|
+
# grid = torch.stack(grid, -1)
|
|
308
|
+
# spatial_dims = list(range(len(size)))
|
|
309
|
+
# grid = grid.permute((*spatial_dims[::-1], len(size)))
|
|
310
|
+
# return grid
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _create_centered_identity_grid2(size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
|
|
314
|
+
space = [torch.linspace((1 - s) / 2, (s - 1) / 2, s) for s in size]
|
|
284
315
|
grid = torch.meshgrid(space, indexing="ij")
|
|
285
316
|
grid = torch.stack(grid, -1)
|
|
286
|
-
spatial_dims = list(range(len(size)))
|
|
287
|
-
grid = grid.permute((*spatial_dims[::-1], len(size)))
|
|
288
317
|
return grid
|
|
289
318
|
|
|
290
319
|
|
|
320
|
+
def _convert_my_grid_to_grid_sample_grid(my_grid: torch.Tensor, original_shape: Union[Tuple[int, ...], List[int]]):
|
|
321
|
+
# rescale
|
|
322
|
+
for d in range(len(original_shape)):
|
|
323
|
+
s = original_shape[d]
|
|
324
|
+
my_grid[..., d] /= (s / 2)
|
|
325
|
+
my_grid = torch.flip(my_grid, (len(my_grid.shape) - 1, ))
|
|
326
|
+
# my_grid = my_grid.flip((len(my_grid.shape) - 1,))
|
|
327
|
+
return my_grid
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
# size = (4, 5, 6)
|
|
331
|
+
# grid_old = _create_identity_grid(size)
|
|
332
|
+
# grid_new = _create_centered_identity_grid2(size)
|
|
333
|
+
# grid_new_converted = _convert_my_grid_to_grid_sample_grid(grid_new, size)
|
|
334
|
+
# torch.all(torch.isclose(grid_new_converted, grid_old))
|
|
335
|
+
|
|
291
336
|
# An alternative way of generating the displacement fieldQ
|
|
292
337
|
# def displacement_field(data: torch.Tensor):
|
|
293
338
|
# downscaling_global = np.random.uniform() ** 2 * 4 + 2
|
|
@@ -308,58 +353,190 @@ def _create_identity_grid(size: List[int]) -> Tensor:
|
|
|
308
353
|
|
|
309
354
|
|
|
310
355
|
if __name__ == '__main__':
|
|
311
|
-
torch.set_num_threads(1)
|
|
356
|
+
# torch.set_num_threads(1)
|
|
357
|
+
#
|
|
358
|
+
# shape = (128, 128, 128)
|
|
359
|
+
# patch_size = (128, 128, 128)
|
|
360
|
+
# labels = 2
|
|
361
|
+
#
|
|
362
|
+
#
|
|
363
|
+
# # seg = torch.rand([i // 32 for i in shape]) * labels
|
|
364
|
+
# # seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
|
|
365
|
+
# # decimals=0).to(torch.int16)
|
|
366
|
+
# # img = torch.ones((1, *shape))
|
|
367
|
+
# # img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
|
|
368
|
+
#
|
|
369
|
+
#
|
|
370
|
+
# import SimpleITK as sitk
|
|
371
|
+
# # img = camera()
|
|
372
|
+
# # seg = None
|
|
373
|
+
# img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/imagesTr/BraTS2021_00000_0000.nii.gz'))
|
|
374
|
+
# seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/labelsTr/BraTS2021_00000.nii.gz'))
|
|
375
|
+
#
|
|
376
|
+
# patch_size = (192, 192, 192)
|
|
377
|
+
# sp = SpatialTransform(
|
|
378
|
+
# patch_size=(192, 192, 192),
|
|
379
|
+
# patch_center_dist_from_border=[i / 2 for i in patch_size],
|
|
380
|
+
# random_crop=True,
|
|
381
|
+
# p_elastic_deform=0,
|
|
382
|
+
# elastic_deform_magnitude=(0.1, 0.1),
|
|
383
|
+
# elastic_deform_scale=(0.1, 0.1),
|
|
384
|
+
# p_synchronize_def_scale_across_axes=0.5,
|
|
385
|
+
# p_rotation=1,
|
|
386
|
+
# rotation=(-30 / 360 * np.pi, 30 / 360 * np.pi),
|
|
387
|
+
# p_scaling=1,
|
|
388
|
+
# scaling=(0.75, 1),
|
|
389
|
+
# p_synchronize_scaling_across_axes=0.5,
|
|
390
|
+
# bg_style_seg_sampling=True,
|
|
391
|
+
# mode_seg='bilinear'
|
|
392
|
+
# )
|
|
393
|
+
#
|
|
394
|
+
# data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
|
|
395
|
+
# if seg is not None:
|
|
396
|
+
# data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
|
|
397
|
+
# # out = sp(**data_dict)
|
|
398
|
+
# #
|
|
399
|
+
# # view_batch(out['image'], out['segmentation'])
|
|
400
|
+
#
|
|
401
|
+
# from time import time
|
|
402
|
+
# times = []
|
|
403
|
+
# for _ in range(10):
|
|
404
|
+
# data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
|
|
405
|
+
# if seg is not None:
|
|
406
|
+
# data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
|
|
407
|
+
# st = time()
|
|
408
|
+
# out = sp(**data_dict)
|
|
409
|
+
# times.append(time() - st)
|
|
410
|
+
# print(np.median(times))
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
#################
|
|
414
|
+
# with this part we can qualitatively test that the correct axes are ebing augmented. Just set one of the probs to 1 and off you go
|
|
415
|
+
#################
|
|
416
|
+
|
|
417
|
+
def eldef_scale(image, dim, patch_size):
|
|
418
|
+
return 0.1
|
|
312
419
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
labels = 2
|
|
420
|
+
def eldef_magnitude(image, dim, patch_size, deformation_scale):
|
|
421
|
+
return 10 if dim == 2 else 0
|
|
316
422
|
|
|
423
|
+
def rot(image, dim):
|
|
424
|
+
return 45/360 * 2 * np.pi if dim == 0 else 0
|
|
317
425
|
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
# decimals=0).to(torch.int16)
|
|
321
|
-
# img = torch.ones((1, *shape))
|
|
322
|
-
# img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
|
|
426
|
+
def scaling(image, dim):
|
|
427
|
+
return 0.5 if dim == 0 else 1
|
|
323
428
|
|
|
429
|
+
# lines
|
|
430
|
+
patch = torch.zeros((1, 64, 60, 68))
|
|
431
|
+
patch[:, :, 10, 30] = 1
|
|
432
|
+
patch[:, 50, :, 30] = 1
|
|
433
|
+
patch[:, 40, 20, :] = 1
|
|
324
434
|
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
435
|
+
# patch_block
|
|
436
|
+
patch_block = torch.zeros((1, 64, 60, 68))
|
|
437
|
+
patch_block[:, 22:42, 20:40, 24:44] = 1
|
|
438
|
+
|
|
439
|
+
patch_line = torch.zeros((1, 64, 60, 128))
|
|
440
|
+
patch_line[:, 22:24, 30:32, 10:-10] = 1
|
|
441
|
+
use = patch_line
|
|
330
442
|
|
|
331
|
-
patch_size = (192, 192, 192)
|
|
332
443
|
sp = SpatialTransform(
|
|
333
|
-
patch_size=
|
|
334
|
-
patch_center_dist_from_border=
|
|
335
|
-
random_crop=
|
|
444
|
+
patch_size=patch.shape[1:],
|
|
445
|
+
patch_center_dist_from_border=0,
|
|
446
|
+
random_crop=False,
|
|
336
447
|
p_elastic_deform=0,
|
|
337
|
-
elastic_deform_magnitude=(0.1, 0.1),
|
|
338
|
-
elastic_deform_scale=(0.1, 0.1),
|
|
339
|
-
p_synchronize_def_scale_across_axes=0.5,
|
|
340
448
|
p_rotation=1,
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
449
|
+
p_scaling=0,
|
|
450
|
+
elastic_deform_scale=eldef_scale,
|
|
451
|
+
elastic_deform_magnitude=eldef_magnitude,
|
|
452
|
+
p_synchronize_def_scale_across_axes=0,
|
|
453
|
+
rotation=rot,
|
|
454
|
+
scaling=scaling,
|
|
455
|
+
p_synchronize_scaling_across_axes=0,
|
|
456
|
+
bg_style_seg_sampling=False,
|
|
346
457
|
mode_seg='bilinear'
|
|
347
458
|
)
|
|
348
459
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
460
|
+
|
|
461
|
+
SimpleITK.WriteImage(SimpleITK.GetImageFromArray(use[0].numpy()), 'orig.nii.gz')
|
|
462
|
+
|
|
463
|
+
params = sp.get_parameters(image=use)
|
|
464
|
+
transformed = sp._apply_to_image(use, **params)
|
|
465
|
+
|
|
466
|
+
SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
|
|
467
|
+
|
|
468
|
+
# p = torch.zeros((1, 1, 8, 16, 32))
|
|
469
|
+
# p[:, :, 2:6, 10:16, 10:24] = 1
|
|
470
|
+
# grid = _create_identity_grid(p.shape[2:])
|
|
471
|
+
# grid[:, :, :, 0] *= 0.5
|
|
472
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
473
|
+
# torch.all(out == p)
|
|
474
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
|
|
475
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
|
|
476
|
+
|
|
477
|
+
#################
|
|
478
|
+
# with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
|
|
479
|
+
#################
|
|
480
|
+
|
|
481
|
+
# sp = SpatialTransform(
|
|
482
|
+
# patch_size=(48, 52, 54),
|
|
483
|
+
# patch_center_dist_from_border=0,
|
|
484
|
+
# random_crop=True,
|
|
485
|
+
# p_elastic_deform=0,
|
|
486
|
+
# p_rotation=1,
|
|
487
|
+
# p_scaling=0,
|
|
488
|
+
# rotation=0
|
|
489
|
+
# )
|
|
490
|
+
# sp2 = SpatialTransform(
|
|
491
|
+
# patch_size=(48, 52, 54),
|
|
492
|
+
# patch_center_dist_from_border=0,
|
|
493
|
+
# random_crop=True,
|
|
494
|
+
# p_elastic_deform=0,
|
|
495
|
+
# p_rotation=0,
|
|
496
|
+
# p_scaling=0,
|
|
497
|
+
# )
|
|
498
|
+
#
|
|
499
|
+
# patch = torch.zeros((1, 64, 60, 68))
|
|
500
|
+
# patch[:, :, 10, 30] = 1
|
|
501
|
+
# patch[:, 50, :, 30] = 1
|
|
502
|
+
# patch[:, 40, 20, :] = 1
|
|
503
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
|
|
504
|
+
#
|
|
505
|
+
# center_coords = [50, 10, 16]
|
|
506
|
+
# params = sp.get_parameters(image=patch)
|
|
507
|
+
# params['center_location_in_pixels'] = center_coords
|
|
508
|
+
# params2 = sp2.get_parameters(image=patch)
|
|
509
|
+
# params2['center_location_in_pixels'] = center_coords
|
|
510
|
+
# transformed = sp._apply_to_image(patch, **params)
|
|
511
|
+
# transformed2 = sp._apply_to_image(patch, **params2)
|
|
353
512
|
#
|
|
354
|
-
#
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
513
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
|
|
514
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
####################
|
|
519
|
+
# This is exploraroty code to check how to retrieve coordinates. I used it to verify that grid_sample does in fact
|
|
520
|
+
# use coordinates in reversed dimension order (zyx and not xyz)
|
|
521
|
+
####################
|
|
522
|
+
# # create a dummy input which has a unique shape in each exis
|
|
523
|
+
# p = torch.zeros((1, 1, 8, 16, 32))
|
|
524
|
+
# # set one pixel to 1
|
|
525
|
+
# p[:, :, 4, 0, 31] = 1
|
|
526
|
+
# # now create an identity grid. I have verified that this grid yields the same image as the input when used in grid_sample. So the grid is correct
|
|
527
|
+
# grid = _create_identity_grid((8, 16, 32)).contiguous() # grid is shape torch.Size([8, 16, 32, 3])
|
|
528
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
529
|
+
# assert torch.all(out == p) # this passes
|
|
530
|
+
# # reduce the grid to the location we are interested in. That are the coordinates where we placed the 1. The 4:5 etc is only so that we keep the number of dimensions
|
|
531
|
+
# grid = grid[4:5, 0:1, 31:32]
|
|
532
|
+
# # What coordinate would we expect? Note that grid is [-1, 1]
|
|
533
|
+
# # For the first dimension, coordinate 4 out of shape 8 is approximately in the middle, so about 0
|
|
534
|
+
# # For the second dimension, coordinate 0 out of shape 16 is very low, so we expect -1 ish (remember there is aligned corners and shit)
|
|
535
|
+
# # For the third dimension, coordinate 31 out of shape 32 is very high, so we expect 1 ish (remember there is aligned corners and shit)
|
|
536
|
+
# # So we expect [0, -1, 1]
|
|
537
|
+
# # What do we get?
|
|
538
|
+
# print(grid)
|
|
539
|
+
# # > tensor([[[[ 0.9688, -0.9375, 0.1250]]]])
|
|
540
|
+
# # not what we expect
|
|
541
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
542
|
+
# assert out.item() == 1
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/random.py
RENAMED
|
@@ -16,4 +16,8 @@ class RandomTransform(BasicTransform):
|
|
|
16
16
|
if params['apply_transform']:
|
|
17
17
|
return self.transform(**data_dict)
|
|
18
18
|
else:
|
|
19
|
-
return data_dict
|
|
19
|
+
return data_dict
|
|
20
|
+
|
|
21
|
+
def __repr__(self):
|
|
22
|
+
ret_str = f"{type(self).__name__}(p={self.apply_probability}, transform={self.transform})"
|
|
23
|
+
return ret_str
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import Union, List, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
|
|
8
|
+
def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
|
|
9
|
+
super().__init__()
|
|
10
|
+
self.regions = [torch.Tensor(i) if not isinstance(i, int) else torch.Tensor([i]) for i in regions]
|
|
11
|
+
self.channel_in_seg = channel_in_seg
|
|
12
|
+
|
|
13
|
+
def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
|
|
14
|
+
num_regions = len(self.regions)
|
|
15
|
+
region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
|
|
16
|
+
for region_id, region_labels in enumerate(self.regions):
|
|
17
|
+
if len(region_labels) == 1:
|
|
18
|
+
region_output[region_id] = segmentation[self.channel_in_seg] == region_labels
|
|
19
|
+
else:
|
|
20
|
+
region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], region_labels)
|
|
21
|
+
# we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to
|
|
22
|
+
# device followed by cast on device should be faster than having fp32 here and transferring that
|
|
23
|
+
return region_output
|
|
24
|
+
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
from typing import Union, List, Tuple
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
|
-
from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
|
|
9
|
-
def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
|
|
10
|
-
super().__init__()
|
|
11
|
-
self.regions = regions
|
|
12
|
-
self.channel_in_seg = channel_in_seg
|
|
13
|
-
|
|
14
|
-
def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
|
|
15
|
-
num_regions = len(self.regions)
|
|
16
|
-
region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
|
|
17
|
-
if isinstance(region_labels, int) or len(region_labels) == 1:
|
|
18
|
-
if not isinstance(region_labels, int):
|
|
19
|
-
region_labels = region_labels[0]
|
|
20
|
-
region_output[:, region_id] = seg[:, self.seg_channel] == region_labels
|
|
21
|
-
else:
|
|
22
|
-
region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
|
|
23
|
-
return region_output.to(segmentation.dtype)
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/benchmarks/unique_values.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/dataloading/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/helpers/scalar_type.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/__init__.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/base/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/intensity/gamma.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/nnunet/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/noise/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/spatial/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/__init__.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/compose.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/cropping.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2/transforms/utils/pseudo2d.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2.1}/batchgeneratorsv2.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|