batchgeneratorsv2 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2/helpers/fft_conv.py +149 -0
- batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
- batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
- batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
- batchgeneratorsv2/transforms/local/__init__.py +0 -0
- batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
- batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
- batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
- batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
- batchgeneratorsv2/transforms/local/local_transform.py +86 -0
- batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
- batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
- batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
- batchgeneratorsv2/transforms/noise/rician.py +61 -0
- batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
- batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
- batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
- batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
- batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
- batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
- batchgeneratorsv2/transforms/utils/random.py +46 -0
- batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
- batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
- batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
- batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
- batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
- batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from typing import Union, List, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MoveSegAsOneHotToDataTransform(BasicTransform):
    def __init__(self, source_channel_idx: int, all_labels: Union[Tuple[int, ...], List[int]],
                 remove_channel_from_source: bool = True):
        """
        Used in nnU-Net to append segmentations from the previous stage to the image as additional input.

        One channel of ``data_dict['segmentation']`` is one-hot encoded (one output channel per entry of
        ``all_labels``, in that order) and concatenated to ``data_dict['image']`` along dim 0.

        Args:
            source_channel_idx: index (along dim 0 of the segmentation) of the channel to encode. Negative
                indices are allowed and count from the end, like Python sequences.
            all_labels: label values to encode. Output channel ``i`` is 1 where the segmentation equals
                ``all_labels[i]`` and 0 elsewhere; values not listed produce all-zero voxels.
            remove_channel_from_source: if True, the encoded channel is removed from
                ``data_dict['segmentation']`` afterwards.
        """
        super().__init__()
        self.source_channel_idx = source_channel_idx
        self.all_labels = all_labels
        self.remove_channel_from_source = remove_channel_from_source

    def apply(self, data_dict, **params):
        image = data_dict['image']
        segmentation = data_dict['segmentation']
        seg = segmentation[self.source_channel_idx]

        # Allocate on the image's device and dtype. A bare torch.zeros would always live on the CPU,
        # which breaks both the masked assignment and torch.cat below for GPU tensors.
        seg_onehot = torch.zeros((len(self.all_labels), *seg.shape), dtype=image.dtype, device=image.device)
        for i, label in enumerate(self.all_labels):
            seg_onehot[i][seg == label] = 1
        data_dict['image'] = torch.cat((image, seg_onehot))

        if self.remove_channel_from_source:
            num_seg_channels = segmentation.shape[0]
            # Normalize negative indices, otherwise e.g. -1 would never match any i below and the
            # encoded channel would silently stay in the segmentation.
            source_idx = self.source_channel_idx % num_seg_channels
            remaining_channels = [i for i in range(num_seg_channels) if i != source_idx]
            data_dict['segmentation'] = segmentation[remaining_channels]
        return data_dict
|
|
File without changes
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from typing import Union, Tuple, List, Callable
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ColorFunctionExtractor:
    """
    Resolve the fill value for one rectangle.

    ``rectangle_value`` is interpreted in this order:
      * a numpy-recognized scalar -> used as-is,
      * a callable -> called with the tensor region that is about to be overwritten,
      * a tuple/list -> unpacked into ``np.random.uniform``.
    Anything else raises RuntimeError. The resolved value is always returned as ``float``.
    """

    def __init__(self, rectangle_value: Union[int, float, Tuple[float, float], Callable]):
        self.rectangle_value = rectangle_value

    def __call__(self, x: torch.Tensor) -> float:
        spec = self.rectangle_value
        if np.isscalar(spec):
            resolved = spec
        elif callable(spec):
            resolved = spec(x)
        elif isinstance(spec, (tuple, list)):
            resolved = np.random.uniform(*spec)
        else:
            raise RuntimeError("Unrecognized format for rectangle_value")
        return float(resolved)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BlankRectangleTransform(ImageOnlyTransform):
    """
    Overwrites random rectangles in the image with a constant or sampled value.

    Supports 2D/3D data and various configurations of rectangle size/value.

    Args:
        rectangle_size: an int (same edge length on every spatial axis), a tuple of ints (one edge length
            per spatial axis) or a tuple of (low, high) ranges, one per spatial axis, sampled with
            np.random.randint (high exclusive). With force_square=True and ranges, only the first range is
            sampled and used for every axis. Sampled sizes are clipped to the image extent.
        rectangle_value: fill value, resolved per rectangle by ColorFunctionExtractor: a scalar, a callable
            receiving the region to be overwritten, or a (low, high) range for np.random.uniform.
        num_rectangles: rectangles per selected channel, either fixed or a (low, high) range sampled with
            np.random.randint (high exclusive).
        force_square: use the same edge length on all axes when sizes are given as ranges.
        p_per_channel: probability that a given channel receives rectangles.
    """

    def __init__(self,
                 rectangle_size: Union[int,
                                       Tuple[int, ...],
                                       Tuple[Tuple[int, int], ...]],
                 rectangle_value: Union[int, float, Tuple[float, float], Callable],
                 num_rectangles: Union[int, Tuple[int, int]],
                 force_square: bool = False,
                 p_per_channel: float = 1.0):
        super().__init__()
        self.rectangle_size = rectangle_size
        self.num_rectangles = num_rectangles
        self.force_square = force_square
        self.p_per_channel = p_per_channel
        self.color_fn = ColorFunctionExtractor(rectangle_value)

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        num_channels = image.shape[0]
        spatial_shape = image.shape[1:]
        ndim = len(spatial_shape)

        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(num_channels)]

        # Number of rectangles per channel (range upper bound is exclusive, np.random.randint semantics)
        if isinstance(self.num_rectangles, (int, np.integer)):
            n_rects = [int(self.num_rectangles)] * num_channels
        else:
            n_rects = [np.random.randint(self.num_rectangles[0], self.num_rectangles[1])
                       for _ in range(num_channels)]

        # Precompute all rectangles for all channels as (lower corner, upper corner) pairs
        rectangles = [[] for _ in range(num_channels)]
        for c in range(num_channels):
            if not apply_channel[c]:
                continue
            for _ in range(n_rects[c]):
                sampled = self._sample_rectangle_size(ndim)
                # A rectangle larger than the image along some axis would make the randint call below
                # fail (high <= low), so clip each edge to the image extent.
                size = [min(int(sampled[d]), int(spatial_shape[d])) for d in range(ndim)]
                lb = [np.random.randint(0, spatial_shape[d] - size[d] + 1) for d in range(ndim)]
                ub = [lb[d] + size[d] for d in range(ndim)]
                rectangles[c].append((lb, ub))

        return {
            'apply_channel': apply_channel,
            'rectangles': rectangles
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        # Writes into img and returns it (no copy is made).
        out = img
        for c, (apply, rects) in enumerate(zip(params['apply_channel'], params['rectangles'])):
            if not apply:
                continue
            for lb, ub in rects:
                slices = tuple(slice(l, u) for l, u in zip(lb, ub))
                intensity = self.color_fn(out[c][slices])
                out[c][slices] = intensity
        return out

    def _sample_rectangle_size(self, ndim: int) -> List[int]:
        """Return one edge length per spatial axis according to ``self.rectangle_size``."""
        size_spec = self.rectangle_size
        if isinstance(size_spec, (int, np.integer)):
            return [size_spec] * ndim

        if isinstance(size_spec, (tuple, list)) and all(isinstance(x, (int, np.integer)) for x in size_spec):
            return list(size_spec)

        if isinstance(size_spec, (tuple, list)) and all(isinstance(x, (tuple, list)) for x in size_spec):
            if self.force_square:
                val = np.random.randint(*size_spec[0])
                return [val] * ndim
            return [np.random.randint(*size_spec[d]) for d in range(ndim)]

        raise RuntimeError("Unrecognized format for rectangle_size")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
if __name__ == '__main__':
    from batchviewer import view_batch

    # Create a random 3D image (C, D, H, W)
    image = torch.rand(1, 32, 64, 64)  # Single-channel 3D volume

    # Instantiate the transform
    transform = BlankRectangleTransform(
        rectangle_size=((4, 10), (10, 20), (10, 20)),  # (Z, Y, X) size ranges
        rectangle_value=(0.0, 1.0),  # Random intensity per rectangle
        num_rectangles=(3, 7),  # 3 to 6 rectangles per channel
        force_square=False,
        p_per_channel=1.0  # Always apply to the channel
    )

    # Sample transform parameters and apply. _apply_to_image writes into its input tensor, so work on a
    # copy; otherwise `image` is modified too and the side-by-side view shows two identical volumes.
    params = transform.get_parameters(image=image)
    image_aug = transform._apply_to_image(image.clone(), **params)

    view_batch(image, image_aug)
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from time import time
|
|
5
|
+
import torch
|
|
6
|
+
from skimage.data import camera
|
|
7
|
+
from torch.nn.functional import pad, conv3d, conv1d, conv2d
|
|
8
|
+
|
|
9
|
+
from batchgeneratorsv2.helpers.fft_conv import fft_conv
|
|
10
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
11
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6):
    """
    Smooths an input image with a 1D Gaussian kernel along the specified dimension.
    The function supports 1D, 2D, and 3D images.

    :param img: Input image tensor with shape (C, X), (C, X, Y), or (C, X, Y, Z),
                where C is the channel dimension and X, Y, Z are spatial dimensions.
    :param sigma: The standard deviation of the Gaussian kernel.
    :param dim_to_blur: The spatial dimension along which to apply the Gaussian blur (0 for X, 1 for Y,
                        2 for Z). Negative values count from the last spatial dimension.
    :param force_use_fft: True -> convolve with fft_conv; False or None -> direct convolution
                          (conv1d/conv2d/conv3d matching the number of spatial dimensions).
    :param truncate: kernel extent in units of sigma, forwarded to _build_kernel.
    :return: The blurred image tensor (same shape as img).
    """
    spatial_dims = img.ndim - 1  # Number of spatial dimensions in the input image
    assert -spatial_dims <= dim_to_blur < spatial_dims, \
        "dim_to_blur must be a valid spatial dimension of the input image."
    dim_to_blur %= spatial_dims  # accept negative indices like Python sequences

    kernel = _build_kernel(sigma, truncate=truncate)
    # _build_kernel returns a CPU float32 tensor. The convolution needs the weights on the image's device
    # and, for floating point images, in the image's dtype (float64 or GPU inputs would fail otherwise).
    kernel = kernel.to(device=img.device, dtype=img.dtype if img.is_floating_point() else kernel.dtype)
    ksize = kernel.shape[0]
    half = ksize // 2

    # Shape the 1D kernel as (c_out=1, c_in=1, 1, ..., ksize, ..., 1) with ksize on the blurred axis.
    kernel_shape = [1] * spatial_dims
    kernel_shape[dim_to_blur] = ksize
    kernel = kernel.view(1, 1, *kernel_shape)

    # F.pad takes (left, right) pairs starting from the LAST dimension, so the pair belonging to
    # dim_to_blur sits at position spatial_dims - 1 - dim_to_blur.
    padding = [0] * (2 * spatial_dims)
    pair = spatial_dims - 1 - dim_to_blur
    padding[2 * pair] = padding[2 * pair + 1] = half

    # Reflect padding requires the pad to be smaller than the padded axis. Large kernels on short axes
    # fall back to replicate padding instead of raising.
    mode = "reflect" if half < img.shape[dim_to_blur + 1] else "replicate"
    img_padded = pad(img, padding, mode=mode)

    conv_op = fft_conv if force_use_fft else {1: conv1d, 2: conv2d, 3: conv3d}[spatial_dims]

    # Weights are [c_out, c_in / groups, ...]. One group per channel gives a depthwise convolution, i.e.
    # every channel is blurred independently with the same kernel.
    num_channels = img_padded.shape[0]
    weight = kernel.expand(num_channels, *[-1] * (kernel.ndim - 1))
    return conv_op(img_padded[None], weight, groups=num_channels)[0]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class GaussianBlurTransform(ImageOnlyTransform):
    # truncate value used for every blur_dimension call made by this transform. The benchmark cache is
    # keyed by the resulting kernel size, so key and actual kernel must be derived from the same value.
    blur_truncate: float = 6

    def __init__(self,
                 blur_sigma: RandomScalar = (1, 5),
                 synchronize_channels: bool = False,  # todo make this p_synchronize_channels
                 synchronize_axes: bool = False,  # todo make this p_synchronize_axes
                 p_per_channel: float = 1,
                 benchmark: bool = False
                 ):
        """
        Gaussian blur implemented with separable 1D gaussian filters (one pass per spatial axis) for speed.

        blur_sigma is resolved through sample_scalar, which is called as
        sample_scalar(blur_sigma, shape, dim=...) where shape is the image shape (c, x(, y, z)) and dim is
        1, 2 or 3 for x, y and z, respectively (dim=None when synchronize_axes is set).

        :param blur_sigma: sigma of the gaussian kernel (see above)
        :param synchronize_channels: if True, all selected channels are blurred with the same sigmas
        :param synchronize_axes: if True, a single sigma is used for all spatial axes
        :param p_per_channel: probability of blurring each channel
        :param benchmark: if True, the first time a (axis length, kernel size) combination is seen, direct
            and fft convolution are timed and the faster one is used from then on
        """
        super().__init__()
        self.blur_sigma = blur_sigma
        self.benchmark = benchmark
        self.synchronize_channels = synchronize_channels
        self.synchronize_axes = synchronize_axes
        self.p_per_channel = p_per_channel
        self.benchmark_use_fft = {}  # shape -> kernel size -> use fft yes or no
        self.benchmark_num_runs = 9

    def get_parameters(self, **data_dict) -> dict:
        shape = data_dict['image'].shape
        dims = len(shape) - 1
        apply_to_channel = torch.rand(shape[0]) < self.p_per_channel
        num_selected = int(apply_to_channel.sum())
        dct = {'apply_to_channel': apply_to_channel}
        if self.synchronize_axes:
            if self.synchronize_channels:
                dct['sigmas'] = [sample_scalar(self.blur_sigma, shape, dim=None)] * dims
            else:
                dct['sigmas'] = [[sample_scalar(self.blur_sigma, shape, dim=None)] * dims
                                 for _ in range(num_selected)]
        else:
            if self.synchronize_channels:
                # callable sigmas receive the same (shape, dim) arguments as in the per-channel branch
                dct['sigmas'] = [sample_scalar(self.blur_sigma, shape, dim=i + 1) for i in range(dims)]
            else:
                dct['sigmas'] = [[sample_scalar(self.blur_sigma, shape, dim=i + 1) for i in range(dims)]
                                 for _ in range(num_selected)]
        return dct

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        apply_to_channel = torch.as_tensor(params['apply_to_channel'])
        # Nothing selected: return early. Besides saving work, this avoids a grouped convolution with
        # zero channels (groups=0) in the synchronized branch below.
        if not bool(apply_to_channel.any()):
            return img
        dim = img.ndim - 1

        if self.synchronize_channels:
            # All selected channels share the sigmas, so they are blurred in one go (the convolution
            # handles arbitrary channel counts via the expanded kernel).
            for d in range(dim):
                img[apply_to_channel] = self._blur(img[apply_to_channel], params['sigmas'][d], d)
        else:
            # Per-channel sigmas: sigmas[j] belongs to the j-th selected channel.
            selected = torch.nonzero(apply_to_channel).flatten().tolist()
            for j, i in enumerate(selected):
                for d in range(dim):
                    img[i:i + 1] = self._blur(img[i:i + 1], params['sigmas'][j][d], d)
        return img

    def _blur(self, img: torch.Tensor, sigma: float, dim_to_blur: int) -> torch.Tensor:
        """Blur one axis, going through the benchmark cache if benchmarking is enabled."""
        if self.benchmark:
            return self._benchmark_wrapper(img, sigma, dim_to_blur)
        return blur_dimension(img, sigma, dim_to_blur, truncate=self.blur_truncate)

    def _median_blur_time(self, img: torch.Tensor, sigma: float, dim_to_blur: int, use_fft: bool) -> float:
        """Median wall time of benchmark_num_runs blur_dimension calls."""
        times = []
        for _ in range(self.benchmark_num_runs):
            st = time()
            blur_dimension(img, sigma, dim_to_blur, force_use_fft=use_fft, truncate=self.blur_truncate)
            times.append(time() - st)
        return np.median(times)

    def _benchmark_wrapper(self, img: torch.Tensor, sigma: float, dim_to_blur: int):
        kernel_size = _compute_kernel_size(sigma, truncate=self.blur_truncate)
        shp = img.shape[dim_to_blur + 1]
        per_shape = self.benchmark_use_fft.setdefault(shp, {})
        if kernel_size not in per_shape:
            # benchmark on a copy so the timing runs cannot touch the original image
            dummy_img = deepcopy(img)
            time_nonfft = self._median_blur_time(dummy_img, sigma, dim_to_blur, use_fft=False)
            time_fft = self._median_blur_time(dummy_img, sigma, dim_to_blur, use_fft=True)
            per_shape[kernel_size] = time_fft < time_nonfft
            # convenience: keep the per-shape entries sorted by kernel size
            self.benchmark_use_fft[shp] = dict(sorted(per_shape.items()))
        # now create the real return value
        return blur_dimension(img, sigma, dim_to_blur,
                              force_use_fft=self.benchmark_use_fft[shp][kernel_size],
                              truncate=self.blur_truncate)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _build_kernel(sigma: float, truncate: float = 4) -> torch.Tensor:
|
|
173
|
+
kernel_size = _compute_kernel_size(sigma, truncate=truncate)
|
|
174
|
+
ksize_half = (kernel_size - 1) * 0.5
|
|
175
|
+
|
|
176
|
+
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
|
177
|
+
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
|
178
|
+
kernel1d = pdf / pdf.sum()
|
|
179
|
+
|
|
180
|
+
return kernel1d
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _round_to_nearest_odd(n):
|
|
184
|
+
rounded = round(n)
|
|
185
|
+
# If the rounded number is odd, return it
|
|
186
|
+
if rounded % 2 == 1:
|
|
187
|
+
return rounded
|
|
188
|
+
# If the rounded number is even, adjust to the nearest odd number
|
|
189
|
+
return rounded + 1 if n - rounded >= 0 else rounded - 1
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _compute_kernel_size(sigma, truncate: float = 4):
    """
    Return an odd Gaussian kernel length for ``sigma``.

    The length is the odd integer nearest to ``sigma * truncate + 0.5``.
    """
    target_length = sigma * truncate + 0.5
    return _round_to_nearest_odd(target_length)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
if __name__ == "__main__":
    # this is the fastest for larger kernels but it doesn't fit well into the current benchmark scheme
    #
    # tmp = np.fft.fftn(offsets[d].numpy())
    # tmp = fourier_gaussian(tmp, sigmas)
    # offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)

    import os
    from batchgenerators.transforms.noise_transforms import GaussianBlurTransform as GBTBG

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    # smoke test on a 2D image
    data_dict = {'image': torch.from_numpy(camera()[None]).float()}
    gnt2 = GaussianBlurTransform(2, False, False, 1, benchmark=False)
    gnt2(**data_dict)

    shape = (128, 164, 64)
    num_warmup_for_benchmark = 1
    num_repeats = 10

    def make_image():
        # (c, x, y, z) volume of ones whose first half along every spatial axis is set to 200
        img = torch.ones((2, *shape))
        img[(slice(None), *[slice(0, s // 2) for s in shape])] = 200
        return img

    def make_image_bg():
        # The same volume in batchgenerators layout (b, c, x, y, z). The channel axis must be sliced too,
        # otherwise the spatial slices shift by one axis and a different region is filled.
        img = np.ones((1, 2, *shape))
        img[(slice(None), slice(None), *[slice(0, s // 2) for s in shape])] = 200
        return img

    def median_runtime(transform, make_data, key):
        times = []
        for _ in range(num_repeats):
            data = {key: make_data()}
            st = time()
            transform(**data)
            times.append(time() - st)
        return np.median(times)

    for sigma_range in (0.1, 1, 10, 20):
        print(shape, sigma_range)
        gnt2 = GaussianBlurTransform(sigma_range, False, False, 1, benchmark=False)
        print('w /o benchmark', median_runtime(gnt2, make_image, 'image'))

        gnt = GaussianBlurTransform(sigma_range, False, False, 1, benchmark=True)
        # warmup (fills the fft/direct convolution cache)
        for _ in range(num_warmup_for_benchmark):
            gnt(image=make_image())
        print('with benchmark', median_runtime(gnt, make_image, 'image'))

        gnt3 = GBTBG(sigma_range, True, True, 0, 1, 1)
        print('batchgenerator', median_runtime(gnt3, make_image_bg, 'data'))
        print()
    # print(gnt.benchmark_use_fft)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from typing import Union, Tuple
|
|
4
|
+
from scipy.ndimage import median_filter
|
|
5
|
+
|
|
6
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MedianFilterTransform(ImageOnlyTransform):
    """
    Applies a median filter (scipy.ndimage.median_filter) to randomly selected image channels.

    Attributes:
        filter_size (int or Tuple[int, int]): fixed filter size, or a (low, high) range sampled with
            np.random.randint (high exclusive).
        p_same_for_each_channel (float): probability that all channels share one sampled filter size.
        p_per_channel (float): probability of filtering a given channel.
    """

    def __init__(self,
                 filter_size: Union[int, Tuple[int, int]],
                 p_same_for_each_channel: float = 0.0,
                 p_per_channel: float = 1.0):
        super().__init__()
        self.filter_size = filter_size
        self.p_same_for_each_channel = p_same_for_each_channel
        self.p_per_channel = p_per_channel

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        num_channels = image.shape[0]
        # always drawn first, even when filter_size is fixed
        share_size = np.random.rand() < self.p_same_for_each_channel

        if isinstance(self.filter_size, int):
            filter_sizes = [self.filter_size for _ in range(num_channels)]
        elif share_size:
            shared = int(np.random.randint(*self.filter_size))
            filter_sizes = [shared for _ in range(num_channels)]
        else:
            filter_sizes = [int(np.random.randint(*self.filter_size)) for _ in range(num_channels)]

        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(num_channels)]
        return {'filter_sizes': filter_sizes, 'apply_channel': apply_channel}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        # scipy operates on numpy arrays, so the data takes a round trip through the CPU.
        # NOTE: for CPU tensors .numpy() shares memory with img, so img itself is overwritten as well.
        arr = img.cpu().numpy()
        chosen = [(c, size)
                  for c, (flag, size) in enumerate(zip(params['apply_channel'], params['filter_sizes']))
                  if flag]
        for c, size in chosen:
            arr[c] = median_filter(arr[c], size=size)
        return torch.from_numpy(arr).to(img.device)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Tuple

import numpy as np
import torch

from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RicianNoiseTransform(ImageOnlyTransform):
    """
    Adds Rician noise to simulate MRI characteristics.

    Two independent Gaussian fields (real and imaginary part) are drawn with standard deviation equal to
    the sampled value. Note that the value is used directly as ``std``, despite the name ``variance``.
    They are combined with the min-shifted image as sqrt((x + n_r)^2 + n_i^2), shifted back, and finally
    rescaled to the mean and std of the input. These statistics are taken over the whole tensor, all
    channels together.

    Args:
        noise_variance (Tuple[float, float]): range for np.random.uniform; the sample is used as the
            Gaussian noise std.
    """

    def __init__(self, noise_variance: Tuple[float, float] = (0.0, 0.1)):
        super().__init__()
        self.noise_variance = noise_variance

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        sampled = np.random.uniform(*self.noise_variance)
        return {'variance': float(sampled)}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        noise_std = params['variance']
        # draw both noise fields first (real part, then imaginary part)
        real_part = torch.empty_like(img).normal_(mean=0.0, std=noise_std)
        imag_part = torch.empty_like(img).normal_(mean=0.0, std=noise_std)

        # shift to a non-negative range before taking the magnitude, then shift back
        offset = img.min()
        magnitude = torch.sqrt((img - offset + real_part).pow(2) + imag_part.pow(2))
        noisy = magnitude + offset

        # rescale to the input's first two moments
        target_mean, target_std = img.mean(), img.std()
        noisy_mean, noisy_std = noisy.mean(), noisy.std()
        if noisy_std > 0:
            return (noisy - noisy_mean) / noisy_std * target_std + target_mean
        # fallback if std is zero (flat image): every voxel becomes the input mean
        return noisy * 0 + target_mean
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
if __name__ == '__main__':
    from batchviewer import view_batch

    # Synthetic 3D image (C, D, H, W): constant 0.5 with a single brighter voxel, so the input is not flat
    image = torch.ones(1, 32, 64, 64) * 0.5
    image[0, 1, 1, 1] = 2

    # Instantiate the transform
    transform = RicianNoiseTransform(noise_variance=(0.05, 0.1))

    # Sample parameters and apply transform (returns a new tensor, `image` is left untouched)
    params = transform.get_parameters(image=image)
    image_noisy = transform._apply_to_image(image, **params)

    view_batch(image, image_noisy)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from typing import Union, Tuple, List
|
|
5
|
+
|
|
6
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SharpeningTransform(ImageOnlyTransform):
    """
    Applies sharpening to 2D or 3D images using Laplacian-based contrast enhancement.
    The result is the original image plus ``strength`` times the (negative) discrete Laplacian.

    Attributes:
        strength (float or Tuple[float, float]): fixed sharpening strength (int or float), or a
            (low, high) range sampled with np.random.uniform.
        p_same_for_each_channel (float): Probability of using the same strength (and clamping decision)
            for all channels.
        p_per_channel (float): Probability of applying sharpening to a given channel.
        p_clamp_intensities (float): Probability of clamping a sharpened channel back to its original
            [min, max] range.
    """

    def __init__(self,
                 strength: Union[float, Tuple[float, float]] = 0.2,
                 p_same_for_each_channel: float = 0.0,
                 p_per_channel: float = 1.0,
                 p_clamp_intensities: float = 0):
        super().__init__()
        self.strength = strength
        self.p_same_for_each_channel = p_same_for_each_channel
        self.p_per_channel = p_per_channel
        self.p_clamp_intensities: float = p_clamp_intensities

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        C = image.shape[0]
        use_same = np.random.rand() < self.p_same_for_each_channel

        if use_same:
            strength = self._sample_strength()
            strengths = [strength] * C
            clamp = np.random.uniform() < self.p_clamp_intensities
            clamp_intensities = [clamp] * C
        else:
            strengths = [self._sample_strength() for _ in range(C)]
            clamp_intensities = [np.random.uniform() < self.p_clamp_intensities for _ in range(C)]

        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(C)]

        return {
            'strengths': strengths,
            'clamp_intensities': clamp_intensities,
            'apply_channel': apply_channel
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        # Writes into img and returns it (no copy is made).
        out = img
        spatial_dims = img.dim() - 1  # 2 for (C, H, W), 3 for (C, D, H, W)
        # Build the kernel in the image dtype so the convolution also works for float64/float16 input.
        kernel_dtype = img.dtype if img.is_floating_point() else torch.float32

        if spatial_dims == 2:
            kernel = torch.tensor([[0, -1, 0],
                                   [-1, 4, -1],
                                   [0, -1, 0]], dtype=kernel_dtype, device=img.device)
            conv = F.conv2d
        elif spatial_dims == 3:
            kernel = torch.tensor([[[0, 0, 0],
                                    [0, -1, 0],
                                    [0, 0, 0]],
                                   [[0, -1, 0],
                                    [-1, 6, -1],
                                    [0, -1, 0]],
                                   [[0, 0, 0],
                                    [0, -1, 0],
                                    [0, 0, 0]]], dtype=kernel_dtype, device=img.device)
            conv = F.conv3d
        else:
            raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}. Expected 2 or 3.")
        kernel = kernel[None, None]  # (1, 1, 3, 3) or (1, 1, 3, 3, 3)
        pad = (1,) * (2 * spatial_dims)  # one voxel on both sides of every spatial axis

        for c, (apply, strength, clamp) in enumerate(zip(params['apply_channel'], params['strengths'],
                                                         params['clamp_intensities'])):
            if not apply:
                continue

            if clamp:
                mn, mx = torch.min(img[c]), torch.max(img[c])

            x = img[c][None, None]  # (1, 1, H, W) or (1, 1, D, H, W)
            padded = F.pad(x, pad, mode='replicate')
            laplace = conv(padded, kernel)

            sharpened = x + strength * laplace
            # Drop only the two leading singleton axes. squeeze() would also drop spatial axes of size 1
            # and then broadcast wrongly (or fail) in the assignment.
            out[c] = sharpened[0, 0]

            if clamp:
                out[c].clamp_(mn, mx)

        return out

    def _sample_strength(self) -> float:
        # Plain numbers (int or float, including numpy scalars) are a fixed strength;
        # anything else is a (low, high) range.
        if isinstance(self.strength, (int, float, np.integer, np.floating)):
            return float(self.strength)
        return float(np.random.uniform(*self.strength))
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
if __name__ == '__main__':
    from skimage.data import camera
    from skimage.util import img_as_float32
    from batchviewer import view_batch

    # Load camera image and prepare it
    img_np = img_as_float32(camera())  # (H, W), float32, values in [0, 1]
    img_torch = torch.from_numpy(img_np).unsqueeze(0)  # (1, H, W) = (C, H, W)

    # Instantiate the transform
    transform = SharpeningTransform(
        strength=(2, 2.1),  # Sharpening strength range
        p_same_for_each_channel=1.0,  # Force same strength for all channels (only 1 channel here)
        p_per_channel=1.0,  # Always apply
        p_clamp_intensities=1
    )

    # Generate parameters and apply. _apply_to_image writes into its input, and img_torch shares memory
    # with img_np (torch.from_numpy), so sharpen a copy to keep the original for the side-by-side view.
    params = transform.get_parameters(image=img_torch)
    sharpened = transform._apply_to_image(img_torch.clone(), **params)
    view_batch(img_np, sharpened)
|
|
File without changes
|