batchgeneratorsv2 0.1.1__tar.gz → 0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/PKG-INFO +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/base/basic_transform.py +5 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/contrast.py +3 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/spatial.py +220 -76
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/random.py +5 -1
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/PKG-INFO +1 -1
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/pyproject.toml +1 -1
- batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/seg_to_regions.py +0 -23
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/LICENSE +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/dataloading/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/helpers/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/base/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/brightness.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/gamma.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/low_resolution.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/compose.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/pseudo2d.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/remove_label.py +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/SOURCES.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/requires.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/readme.md +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/setup.cfg +0 -0
- {batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/setup.py +0 -0
|
@@ -135,4 +135,4 @@ if __name__ == '__main__':
|
|
|
135
135
|
compute_times[i].append(time() - st)
|
|
136
136
|
|
|
137
137
|
for t, ct in zip(transforms, compute_times):
|
|
138
|
-
print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.
|
|
138
|
+
print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.mean(ct))
|
|
@@ -53,6 +53,11 @@ class BasicTransform(abc.ABC):
|
|
|
53
53
|
def get_parameters(self, **data_dict) -> dict:
|
|
54
54
|
return {}
|
|
55
55
|
|
|
56
|
+
def __repr__(self):
|
|
57
|
+
ret_str = str(type(self).__name__) + "( " + ", ".join(
|
|
58
|
+
[key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )"
|
|
59
|
+
return ret_str
|
|
60
|
+
|
|
56
61
|
|
|
57
62
|
class ImageOnlyTransform(BasicTransform):
|
|
58
63
|
def apply(self, data_dict: dict, **params) -> dict:
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/contrast.py
RENAMED
|
@@ -22,6 +22,9 @@ class BGContrast():
|
|
|
22
22
|
def __call__(self, *args, **kwargs):
|
|
23
23
|
return self.sample_contrast(*args, **kwargs)
|
|
24
24
|
|
|
25
|
+
def __repr__(self):
|
|
26
|
+
return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
|
|
27
|
+
|
|
25
28
|
class ContrastTransform(ImageOnlyTransform):
|
|
26
29
|
def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
|
|
27
30
|
super().__init__()
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/spatial.py
RENAMED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from copy import deepcopy
|
|
2
2
|
from typing import Tuple, List, Union
|
|
3
3
|
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
import SimpleITK
|
|
4
7
|
import numpy as np
|
|
5
8
|
import pandas as pd
|
|
6
9
|
import torch
|
|
@@ -14,14 +17,19 @@ from batchgeneratorsv2.transforms.utils.cropping import crop_tensor
|
|
|
14
17
|
|
|
15
18
|
|
|
16
19
|
class SpatialTransform(BasicTransform):
|
|
17
|
-
def __init__(self,
|
|
20
|
+
def __init__(self,
|
|
21
|
+
patch_size: Tuple[int, ...],
|
|
18
22
|
patch_center_dist_from_border: Union[int, List[int], Tuple[int, ...]],
|
|
19
23
|
random_crop: bool,
|
|
20
|
-
p_elastic_deform: float = 0,
|
|
24
|
+
p_elastic_deform: float = 0,
|
|
25
|
+
elastic_deform_scale: RandomScalar = (0, 0.2),
|
|
21
26
|
elastic_deform_magnitude: RandomScalar = (0, 0.2),
|
|
22
|
-
p_synchronize_def_scale_across_axes: float =
|
|
23
|
-
p_rotation: float = 0,
|
|
24
|
-
|
|
27
|
+
p_synchronize_def_scale_across_axes: float = 0,
|
|
28
|
+
p_rotation: float = 0,
|
|
29
|
+
rotation: RandomScalar = (0, 2 * np.pi),
|
|
30
|
+
p_scaling: float = 0,
|
|
31
|
+
scaling: RandomScalar = (0.7, 1.3),
|
|
32
|
+
p_synchronize_scaling_across_axes: float = 0,
|
|
25
33
|
bg_style_seg_sampling: bool = True,
|
|
26
34
|
mode_seg: str = 'bilinear'
|
|
27
35
|
):
|
|
@@ -44,6 +52,7 @@ class SpatialTransform(BasicTransform):
|
|
|
44
52
|
self.mode_seg = mode_seg
|
|
45
53
|
|
|
46
54
|
def get_parameters(self, **data_dict) -> dict:
|
|
55
|
+
# note that we revert the axis order here because grid_sample uses dimensions in reverse order!
|
|
47
56
|
dim = data_dict['image'].ndim - 1
|
|
48
57
|
|
|
49
58
|
do_rotation = np.random.uniform() < self.p_rotation
|
|
@@ -51,21 +60,23 @@ class SpatialTransform(BasicTransform):
|
|
|
51
60
|
do_deform = np.random.uniform() < self.p_elastic_deform
|
|
52
61
|
|
|
53
62
|
if do_rotation:
|
|
54
|
-
angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(dim)]
|
|
63
|
+
angles = [sample_scalar(self.rotation, image=data_dict['image'], dim=i) for i in range(dim - 1, -1, -1)]
|
|
55
64
|
else:
|
|
56
65
|
angles = [0] * dim
|
|
57
66
|
if do_scale:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
67
|
+
if np.random.uniform() <= self.p_synchronize_scaling_across_axes:
|
|
68
|
+
scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=None)] * dim
|
|
69
|
+
else:
|
|
70
|
+
scales = [sample_scalar(self.scaling, image=data_dict['image'], dim=i) for i in range(dim - 1, -1, -1)]
|
|
61
71
|
else:
|
|
62
72
|
scales = [1] * dim
|
|
73
|
+
|
|
63
74
|
# affine matrix
|
|
64
75
|
if do_scale or do_rotation:
|
|
65
76
|
if dim == 3:
|
|
66
77
|
affine = create_affine_matrix_3d(angles, scales)
|
|
67
78
|
elif dim == 2:
|
|
68
|
-
affine = create_affine_matrix_2d(angles[
|
|
79
|
+
affine = create_affine_matrix_2d(angles[-1], scales)
|
|
69
80
|
else:
|
|
70
81
|
raise RuntimeError(f'Unsupported dimension: {dim}')
|
|
71
82
|
else:
|
|
@@ -74,19 +85,24 @@ class SpatialTransform(BasicTransform):
|
|
|
74
85
|
# elastic deformation. We need to create the displacement field here
|
|
75
86
|
# we use the method from augment_spatial_2 in batchgenerators
|
|
76
87
|
if do_deform:
|
|
77
|
-
grid_scale = [i / j for i, j in zip(data_dict['image'].shape[1:], self.patch_size)]
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
88
|
+
grid_scale = [i / j for i, j in zip(data_dict['image'].shape[1:], self.patch_size)][::-1]
|
|
89
|
+
if np.random.uniform() <= self.p_synchronize_def_scale_across_axes:
|
|
90
|
+
deformation_scales = [
|
|
91
|
+
sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=None, patch_size=self.patch_size)
|
|
92
|
+
] * dim
|
|
93
|
+
else:
|
|
94
|
+
deformation_scales = [
|
|
95
|
+
sample_scalar(self.elastic_deform_scale, image=data_dict['image'], dim=i, patch_size=self.patch_size)
|
|
96
|
+
for i in range(dim - 1, -1, -1)
|
|
97
|
+
]
|
|
98
|
+
|
|
83
99
|
# sigmas must be in pixels, as this will be applied to the deformation field
|
|
84
|
-
sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)]
|
|
100
|
+
sigmas = [i * j for i, j in zip(deformation_scales, self.patch_size)][::-1]
|
|
85
101
|
# the magnitude of the deformation field must adhere to the torch's value range for grid_sample, i.e. [-1. 1] and not pixel coordinates. Do not use sigmas here
|
|
86
102
|
# we need to correct magnitude by grid_scale to account for the fact that the grid will be wrt to the image size but the magnitude should be wrt the patch size. oof.
|
|
87
103
|
magnitude = [
|
|
88
104
|
sample_scalar(self.elastic_deform_magnitude, image=data_dict['image'], patch_size=self.patch_size,
|
|
89
|
-
dim=i, deformation_scale=deformation_scales[i]) / grid_scale[i] for i in range(dim)]
|
|
105
|
+
dim=i, deformation_scale=deformation_scales[i]) / grid_scale[i] for i in range(dim - 1, -1, -1)]
|
|
90
106
|
# doing it like this for better memory layout for blurring
|
|
91
107
|
offsets = torch.normal(mean=0, std=1, size=(dim, *self.patch_size))
|
|
92
108
|
|
|
@@ -98,7 +114,7 @@ class SpatialTransform(BasicTransform):
|
|
|
98
114
|
|
|
99
115
|
# fft numpy, this is faster o.O
|
|
100
116
|
tmp = np.fft.fftn(offsets[d].numpy())
|
|
101
|
-
tmp = fourier_gaussian(tmp, sigmas)
|
|
117
|
+
tmp = fourier_gaussian(tmp, sigmas[d])
|
|
102
118
|
offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)
|
|
103
119
|
|
|
104
120
|
mx = torch.max(torch.abs(offsets[d]))
|
|
@@ -109,10 +125,10 @@ class SpatialTransform(BasicTransform):
|
|
|
109
125
|
# grid center must be in [-1, 1] as required by grid_sample
|
|
110
126
|
shape = data_dict['image'].shape[1:]
|
|
111
127
|
if not self.random_crop:
|
|
112
|
-
center_location_in_pixels = [i / 2 for i in shape]
|
|
128
|
+
center_location_in_pixels = [i / 2 for i in shape][::-1]
|
|
113
129
|
else:
|
|
114
130
|
center_location_in_pixels = []
|
|
115
|
-
for d in range(dim):
|
|
131
|
+
for d in range(dim - 1, -1, -1):
|
|
116
132
|
mn = self.patch_center_dist_from_border[d]
|
|
117
133
|
mx = shape[d] - self.patch_center_dist_from_border[d]
|
|
118
134
|
if mx < mn:
|
|
@@ -130,14 +146,14 @@ class SpatialTransform(BasicTransform):
|
|
|
130
146
|
# No spatial transformation is being done. Round grid_center and crop without having to interpolate.
|
|
131
147
|
# This saves compute.
|
|
132
148
|
# cropping requires the center to be given as integer coordinates
|
|
133
|
-
img = crop_tensor(img, [
|
|
149
|
+
img = crop_tensor(img, [math.floor(i) for i in params['center_location_in_pixels']][::-1], self.patch_size, pad_mode='constant',
|
|
134
150
|
pad_kwargs={'value': 0})
|
|
135
151
|
return img
|
|
136
152
|
else:
|
|
137
153
|
grid = _create_identity_grid(self.patch_size)
|
|
138
154
|
|
|
139
155
|
# the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
|
|
140
|
-
grid_scale = torch.Tensor([i / j for i, j in zip(img.shape[1:], self.patch_size)])
|
|
156
|
+
grid_scale = torch.Tensor([i / j for i, j in zip(img.shape[1:], self.patch_size)][::-1])
|
|
141
157
|
grid /= grid_scale
|
|
142
158
|
|
|
143
159
|
# we deform first, then rotate
|
|
@@ -147,9 +163,13 @@ class SpatialTransform(BasicTransform):
|
|
|
147
163
|
grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
|
|
148
164
|
|
|
149
165
|
# we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center position
|
|
150
|
-
|
|
166
|
+
# only do this if we elastic deform
|
|
167
|
+
if params['elastic_offsets'] is not None:
|
|
168
|
+
mn = grid.mean(dim=list(range(img.ndim - 1)))
|
|
169
|
+
else:
|
|
170
|
+
mn = 0
|
|
151
171
|
new_center = torch.Tensor(
|
|
152
|
-
[(j / (i / 2) - 1) for i, j in zip(img.shape[1:], params['center_location_in_pixels'])])
|
|
172
|
+
[(j / (i / 2) - 1) for i, j in zip(img.shape[1:][::-1], params['center_location_in_pixels'])])
|
|
153
173
|
grid += - mn + new_center
|
|
154
174
|
return grid_sample(img[None], grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)[0]
|
|
155
175
|
|
|
@@ -159,14 +179,17 @@ class SpatialTransform(BasicTransform):
|
|
|
159
179
|
# No spatial transformation is being done. Round grid_center and crop without having to interpolate.
|
|
160
180
|
# This saves compute.
|
|
161
181
|
# cropping requires the center to be given as integer coordinates
|
|
162
|
-
segmentation = crop_tensor(segmentation,
|
|
163
|
-
|
|
182
|
+
segmentation = crop_tensor(segmentation,
|
|
183
|
+
[math.floor(i) for i in params['center_location_in_pixels']][::-1],
|
|
184
|
+
self.patch_size,
|
|
185
|
+
pad_mode='constant',
|
|
186
|
+
pad_kwargs={'value': 0})
|
|
164
187
|
return segmentation
|
|
165
188
|
else:
|
|
166
189
|
grid = _create_identity_grid(self.patch_size)
|
|
167
190
|
|
|
168
191
|
# the grid must be scaled. The grid is [-1, 1] in image coordinates, but we want it to represent the smaller patch
|
|
169
|
-
grid_scale = torch.Tensor([i / j for i, j in zip(segmentation.shape[1:], self.patch_size)])
|
|
192
|
+
grid_scale = torch.Tensor([i / j for i, j in zip(segmentation.shape[1:], self.patch_size)][::-1])
|
|
170
193
|
grid /= grid_scale
|
|
171
194
|
|
|
172
195
|
# we deform first, then rotate
|
|
@@ -176,9 +199,12 @@ class SpatialTransform(BasicTransform):
|
|
|
176
199
|
grid = torch.matmul(grid, torch.from_numpy(params['affine']).float())
|
|
177
200
|
|
|
178
201
|
# we center the grid around the center_location_in_pixels. We should center the mean of the grid, not the center coordinate
|
|
179
|
-
|
|
202
|
+
if params['elastic_offsets'] is not None:
|
|
203
|
+
mn = grid.mean(dim=list(range(segmentation.ndim - 1)))
|
|
204
|
+
else:
|
|
205
|
+
mn = 0
|
|
180
206
|
new_center = torch.Tensor(
|
|
181
|
-
[(j / (i / 2) - 1) for i, j in zip(segmentation.shape[1:], params['center_location_in_pixels'])])
|
|
207
|
+
[(j / (i / 2) - 1) for i, j in zip(segmentation.shape[1:][::-1], params['center_location_in_pixels'])])
|
|
182
208
|
grid += - mn + new_center
|
|
183
209
|
|
|
184
210
|
if self.mode_seg == 'nearest':
|
|
@@ -308,58 +334,176 @@ def _create_identity_grid(size: List[int]) -> Tensor:
|
|
|
308
334
|
|
|
309
335
|
|
|
310
336
|
if __name__ == '__main__':
|
|
311
|
-
torch.set_num_threads(1)
|
|
312
|
-
|
|
313
|
-
shape = (128, 128, 128)
|
|
314
|
-
patch_size = (128, 128, 128)
|
|
315
|
-
labels = 2
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
# seg = torch.rand([i // 32 for i in shape]) * labels
|
|
319
|
-
# seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
|
|
320
|
-
# decimals=0).to(torch.int16)
|
|
321
|
-
# img = torch.ones((1, *shape))
|
|
322
|
-
# img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
|
|
337
|
+
# torch.set_num_threads(1)
|
|
338
|
+
#
|
|
339
|
+
# shape = (128, 128, 128)
|
|
340
|
+
# patch_size = (128, 128, 128)
|
|
341
|
+
# labels = 2
|
|
342
|
+
#
|
|
343
|
+
#
|
|
344
|
+
# # seg = torch.rand([i // 32 for i in shape]) * labels
|
|
345
|
+
# # seg_up = torch.round(torch.nn.functional.interpolate(seg[None, None], size=shape, mode='trilinear')[0],
|
|
346
|
+
# # decimals=0).to(torch.int16)
|
|
347
|
+
# # img = torch.ones((1, *shape))
|
|
348
|
+
# # img[tuple([slice(img.shape[0])] + [slice(i // 4, i // 4 * 2) for i in shape])] = 200
|
|
349
|
+
#
|
|
350
|
+
#
|
|
351
|
+
# import SimpleITK as sitk
|
|
352
|
+
# # img = camera()
|
|
353
|
+
# # seg = None
|
|
354
|
+
# img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/imagesTr/BraTS2021_00000_0000.nii.gz'))
|
|
355
|
+
# seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/labelsTr/BraTS2021_00000.nii.gz'))
|
|
356
|
+
#
|
|
357
|
+
# patch_size = (192, 192, 192)
|
|
358
|
+
# sp = SpatialTransform(
|
|
359
|
+
# patch_size=(192, 192, 192),
|
|
360
|
+
# patch_center_dist_from_border=[i / 2 for i in patch_size],
|
|
361
|
+
# random_crop=True,
|
|
362
|
+
# p_elastic_deform=0,
|
|
363
|
+
# elastic_deform_magnitude=(0.1, 0.1),
|
|
364
|
+
# elastic_deform_scale=(0.1, 0.1),
|
|
365
|
+
# p_synchronize_def_scale_across_axes=0.5,
|
|
366
|
+
# p_rotation=1,
|
|
367
|
+
# rotation=(-30 / 360 * np.pi, 30 / 360 * np.pi),
|
|
368
|
+
# p_scaling=1,
|
|
369
|
+
# scaling=(0.75, 1),
|
|
370
|
+
# p_synchronize_scaling_across_axes=0.5,
|
|
371
|
+
# bg_style_seg_sampling=True,
|
|
372
|
+
# mode_seg='bilinear'
|
|
373
|
+
# )
|
|
374
|
+
#
|
|
375
|
+
# data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
|
|
376
|
+
# if seg is not None:
|
|
377
|
+
# data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
|
|
378
|
+
# # out = sp(**data_dict)
|
|
379
|
+
# #
|
|
380
|
+
# # view_batch(out['image'], out['segmentation'])
|
|
381
|
+
#
|
|
382
|
+
# from time import time
|
|
383
|
+
# times = []
|
|
384
|
+
# for _ in range(10):
|
|
385
|
+
# data_dict = {'image': torch.from_numpy(deepcopy(img[None])).float()}
|
|
386
|
+
# if seg is not None:
|
|
387
|
+
# data_dict['segmentation'] = torch.from_numpy(deepcopy(seg[None]))
|
|
388
|
+
# st = time()
|
|
389
|
+
# out = sp(**data_dict)
|
|
390
|
+
# times.append(time() - st)
|
|
391
|
+
# print(np.median(times))
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
#################
|
|
395
|
+
# with this part we can qualitatively test that the correct axes are ebing augmented. Just set one of the probs to 1 and off you go
|
|
396
|
+
#################
|
|
397
|
+
|
|
398
|
+
# def constant_scaling(image, dim, patch_size):
|
|
399
|
+
# return 0.1
|
|
400
|
+
#
|
|
401
|
+
# def constant_magnitude(image, dim, patch_size, deformation_scale):
|
|
402
|
+
# return 0.25 if dim == 2 else 0
|
|
403
|
+
#
|
|
404
|
+
# def rot(image, dim):
|
|
405
|
+
# return 45/360 * 2 * np.pi if dim == 1 else 0
|
|
406
|
+
#
|
|
407
|
+
# sp = SpatialTransform(
|
|
408
|
+
# patch_size=(64, 60, 68),
|
|
409
|
+
# patch_center_dist_from_border=0,
|
|
410
|
+
# random_crop=False,
|
|
411
|
+
# p_elastic_deform=0,
|
|
412
|
+
# elastic_deform_scale=0,
|
|
413
|
+
# elastic_deform_magnitude=0,
|
|
414
|
+
# p_synchronize_def_scale_across_axes=0,
|
|
415
|
+
# p_rotation=1,
|
|
416
|
+
# rotation=rot,
|
|
417
|
+
# p_scaling=0,
|
|
418
|
+
# scaling=constant_scaling,
|
|
419
|
+
# p_synchronize_scaling_across_axes=0,
|
|
420
|
+
# bg_style_seg_sampling=False,
|
|
421
|
+
# mode_seg='bilinear'
|
|
422
|
+
# )
|
|
423
|
+
#
|
|
424
|
+
# patch = torch.zeros((1, 64, 60, 68))
|
|
425
|
+
# patch[:, :, 10, 30] = 1
|
|
426
|
+
# patch[:, 50, :, 30] = 1
|
|
427
|
+
# patch[:, 40, 20, :] = 1
|
|
428
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
|
|
429
|
+
#
|
|
430
|
+
# params = sp.get_parameters(image=patch)
|
|
431
|
+
# transformed = sp._apply_to_image(patch, **params)
|
|
432
|
+
#
|
|
433
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
|
|
323
434
|
|
|
435
|
+
# p = torch.zeros((1, 1, 8, 16, 32))
|
|
436
|
+
# p[:, :, 2:6, 10:16, 10:24] = 1
|
|
437
|
+
# grid = _create_identity_grid(p.shape[2:])
|
|
438
|
+
# grid[:, :, :, 0] *= 0.5
|
|
439
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
440
|
+
# torch.all(out == p)
|
|
441
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(p[0, 0].numpy()), 'orig.nii.gz')
|
|
442
|
+
# SimpleITK.WriteImage(SimpleITK.GetImageFromArray(out[0, 0].numpy()), 'transformed.nii.gz')
|
|
324
443
|
|
|
325
|
-
|
|
326
|
-
#
|
|
327
|
-
|
|
328
|
-
img = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/imagesTr/BraTS2021_00000_0000.nii.gz'))
|
|
329
|
-
seg = sitk.GetArrayFromImage(sitk.ReadImage('/media/isensee/raw_data/nnUNet_raw/Dataset137_BraTS2021/labelsTr/BraTS2021_00000.nii.gz'))
|
|
444
|
+
#################
|
|
445
|
+
# with this part I verify that the crop through spatialtransforms grid sample yields the same result as crop_tensor
|
|
446
|
+
#################
|
|
330
447
|
|
|
331
|
-
patch_size = (192, 192, 192)
|
|
332
448
|
sp = SpatialTransform(
|
|
333
|
-
patch_size=(
|
|
334
|
-
patch_center_dist_from_border=
|
|
449
|
+
patch_size=(48, 52, 54),
|
|
450
|
+
patch_center_dist_from_border=0,
|
|
335
451
|
random_crop=True,
|
|
336
452
|
p_elastic_deform=0,
|
|
337
|
-
elastic_deform_magnitude=(0.1, 0.1),
|
|
338
|
-
elastic_deform_scale=(0.1, 0.1),
|
|
339
|
-
p_synchronize_def_scale_across_axes=0.5,
|
|
340
453
|
p_rotation=1,
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
454
|
+
p_scaling=0,
|
|
455
|
+
rotation=0
|
|
456
|
+
)
|
|
457
|
+
sp2 = SpatialTransform(
|
|
458
|
+
patch_size=(48, 52, 54),
|
|
459
|
+
patch_center_dist_from_border=0,
|
|
460
|
+
random_crop=True,
|
|
461
|
+
p_elastic_deform=0,
|
|
462
|
+
p_rotation=0,
|
|
463
|
+
p_scaling=0,
|
|
347
464
|
)
|
|
348
465
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
466
|
+
patch = torch.zeros((1, 64, 60, 68))
|
|
467
|
+
patch[:, :, 10, 30] = 1
|
|
468
|
+
patch[:, 50, :, 30] = 1
|
|
469
|
+
patch[:, 40, 20, :] = 1
|
|
470
|
+
SimpleITK.WriteImage(SimpleITK.GetImageFromArray(patch[0].numpy()), 'orig.nii.gz')
|
|
471
|
+
|
|
472
|
+
center_coords = [30, 28, 44]
|
|
473
|
+
params = sp.get_parameters(image=patch)
|
|
474
|
+
params['center_location_in_pixels'] = center_coords
|
|
475
|
+
params2 = sp2.get_parameters(image=patch)
|
|
476
|
+
params2['center_location_in_pixels'] = center_coords
|
|
477
|
+
transformed = sp._apply_to_image(patch, **params)
|
|
478
|
+
transformed2 = sp._apply_to_image(patch, **params)
|
|
479
|
+
|
|
480
|
+
SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed[0].numpy()), 'transformed.nii.gz')
|
|
481
|
+
SimpleITK.WriteImage(SimpleITK.GetImageFromArray(transformed2[0].numpy()), 'transformed2.nii.gz')
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
####################
|
|
486
|
+
# This is exploraroty code to check how to retrieve coordinates. I used it to verify that grid_sample does in fact
|
|
487
|
+
# use coordinates in reversed dimension order (zyx and not xyz)
|
|
488
|
+
####################
|
|
489
|
+
# # create a dummy input which has a unique shape in each exis
|
|
490
|
+
# p = torch.zeros((1, 1, 8, 16, 32))
|
|
491
|
+
# # set one pixel to 1
|
|
492
|
+
# p[:, :, 4, 0, 31] = 1
|
|
493
|
+
# # now create an identity grid. I have verified that this grid yields the same image as the input when used in grid_sample. So the grid is correct
|
|
494
|
+
# grid = _create_identity_grid((8, 16, 32)).contiguous() # grid is shape torch.Size([8, 16, 32, 3])
|
|
495
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
496
|
+
# assert torch.all(out == p) # this passes
|
|
497
|
+
# # reduce the grid to the location we are interested in. That are the coordinates where we placed the 1. The 4:5 etc is only so that we keep the number of dimensions
|
|
498
|
+
# grid = grid[4:5, 0:1, 31:32]
|
|
499
|
+
# # What coordinate would we expect? Note that grid is [-1, 1]
|
|
500
|
+
# # For the first dimension, coordinate 4 out of shape 8 is approximately in the middle, so about 0
|
|
501
|
+
# # For the second dimension, coordinate 0 out of shape 16 is very low, so we expect -1 ish (remember there is aligned corners and shit)
|
|
502
|
+
# # For the third dimension, coordinate 31 out of shape 32 is very high, so we expect 1 ish (remember there is aligned corners and shit)
|
|
503
|
+
# # So we expect [0, -1, 1]
|
|
504
|
+
# # What do we get?
|
|
505
|
+
# print(grid)
|
|
506
|
+
# # > tensor([[[[ 0.9688, -0.9375, 0.1250]]]])
|
|
507
|
+
# # not what we expect
|
|
508
|
+
# out = grid_sample(p, grid[None], mode='bilinear', padding_mode="zeros", align_corners=False)
|
|
509
|
+
# assert out.item() == 1
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/random.py
RENAMED
|
@@ -16,4 +16,8 @@ class RandomTransform(BasicTransform):
|
|
|
16
16
|
if params['apply_transform']:
|
|
17
17
|
return self.transform(**data_dict)
|
|
18
18
|
else:
|
|
19
|
-
return data_dict
|
|
19
|
+
return data_dict
|
|
20
|
+
|
|
21
|
+
def __repr__(self):
|
|
22
|
+
ret_str = f"{type(self).__name__}(p={self.apply_probability}, transform={self.transform})"
|
|
23
|
+
return ret_str
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import Union, List, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
|
|
8
|
+
def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
|
|
9
|
+
super().__init__()
|
|
10
|
+
self.regions = [torch.Tensor(i) if not isinstance(i, int) else torch.Tensor([i]) for i in regions]
|
|
11
|
+
self.channel_in_seg = channel_in_seg
|
|
12
|
+
|
|
13
|
+
def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
|
|
14
|
+
num_regions = len(self.regions)
|
|
15
|
+
region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
|
|
16
|
+
for region_id, region_labels in enumerate(self.regions):
|
|
17
|
+
if len(region_labels) == 1:
|
|
18
|
+
region_output[region_id] = segmentation[self.channel_in_seg] == region_labels
|
|
19
|
+
else:
|
|
20
|
+
region_output[region_id] = torch.isin(segmentation[self.channel_in_seg], region_labels)
|
|
21
|
+
# we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to
|
|
22
|
+
# device followed by cast on device should be faster than having fp32 here and transferring that
|
|
23
|
+
return region_output
|
|
24
|
+
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
from typing import Union, List, Tuple
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
|
-
from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
|
|
9
|
-
def __init__(self, regions: Union[List, Tuple], channel_in_seg: int = 0):
|
|
10
|
-
super().__init__()
|
|
11
|
-
self.regions = regions
|
|
12
|
-
self.channel_in_seg = channel_in_seg
|
|
13
|
-
|
|
14
|
-
def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
|
|
15
|
-
num_regions = len(self.regions)
|
|
16
|
-
region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
|
|
17
|
-
if isinstance(region_labels, int) or len(region_labels) == 1:
|
|
18
|
-
if not isinstance(region_labels, int):
|
|
19
|
-
region_labels = region_labels[0]
|
|
20
|
-
region_output[:, region_id] = seg[:, self.seg_channel] == region_labels
|
|
21
|
-
else:
|
|
22
|
-
region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
|
|
23
|
-
return region_output.to(segmentation.dtype)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/benchmarks/unique_values.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/base/__init__.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/intensity/gamma.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/nnunet/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/noise/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/spatial/mirroring.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/__init__.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/compose.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/cropping.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/pseudo2d.py
RENAMED
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/transforms/utils/remove_label.py
RENAMED
|
File without changes
|
|
File without changes
|
{batchgeneratorsv2-0.1.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|