batchgeneratorsv2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. batchgeneratorsv2/benchmarks/__init__.py +0 -0
  2. batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  3. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  4. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  5. batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  6. batchgeneratorsv2/dataloading/__init__.py +0 -0
  7. batchgeneratorsv2/helpers/__init__.py +0 -0
  8. batchgeneratorsv2/helpers/fft_conv.py +149 -0
  9. batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
  13. batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  15. batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
  16. batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  17. batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  18. batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  19. batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  20. batchgeneratorsv2/transforms/local/__init__.py +0 -0
  21. batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  22. batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  23. batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  24. batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  25. batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  26. batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
  28. batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  29. batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  30. batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  32. batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  33. batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  34. batchgeneratorsv2/transforms/noise/rician.py +61 -0
  35. batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  36. batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  37. batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  38. batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
  39. batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  40. batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  41. batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
  42. batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  43. batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  44. batchgeneratorsv2/transforms/utils/compose.py +89 -0
  45. batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  46. batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  47. batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  48. batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
  49. batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
  50. batchgeneratorsv2/transforms/utils/random.py +46 -0
  51. batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
  52. batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  53. batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
  54. batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
  55. batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
  56. batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
  57. batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,32 @@
1
+ from typing import Union, List, Tuple
2
+
3
+ import torch
4
+
5
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
6
+
7
+
8
class MoveSegAsOneHotToDataTransform(BasicTransform):
    def __init__(self, source_channel_idx: int, all_labels: Union[Tuple[int, ...], List[int]],
                 remove_channel_from_source: bool = True):
        """
        Used in nnU-Net to append segmentations from the previous stage to the image as additional input.

        Channel ``source_channel_idx`` of ``data_dict['segmentation']`` is one-hot encoded (one output channel per
        entry of ``all_labels``, in that order) and concatenated to ``data_dict['image']`` along the first axis.

        Args:
            source_channel_idx: index of the segmentation channel to encode (negative indices are allowed).
            all_labels: label values to encode; voxels whose value is not listed are zero in every one-hot channel.
            remove_channel_from_source: if True, the encoded channel is removed from ``data_dict['segmentation']``.
        """
        super().__init__()
        self.source_channel_idx = source_channel_idx
        self.all_labels = all_labels
        self.remove_channel_from_source = remove_channel_from_source

    def apply(self, data_dict, **params):
        image = data_dict['image']
        segmentation = data_dict['segmentation']
        seg = segmentation[self.source_channel_idx]
        # Allocate on the image's device and dtype: a CPU buffer could neither be indexed with a CUDA mask nor be
        # concatenated with a CUDA image.
        seg_onehot = torch.zeros((len(self.all_labels), *seg.shape), dtype=image.dtype, device=image.device)
        for i, l in enumerate(self.all_labels):
            seg_onehot[i][seg == l] = 1
        data_dict['image'] = torch.cat((image, seg_onehot))
        if self.remove_channel_from_source:
            num_seg_channels = segmentation.shape[0]
            # Normalize negative indices so that e.g. -1 removes the channel that was actually encoded above.
            source_pos = self.source_channel_idx % num_seg_channels
            remaining_channels = [i for i in range(num_seg_channels) if i != source_pos]
            data_dict['segmentation'] = segmentation[remaining_channels]
        return data_dict
File without changes
@@ -0,0 +1,150 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Union, Tuple, List, Callable
4
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
5
+
6
+
7
class ColorFunctionExtractor:
    """Turn a ``rectangle_value`` specification into a concrete fill value.

    Supported specifications, checked in this order:
      * a scalar, which is returned as a float;
      * a callable, which is invoked with the region about to be overwritten;
      * a tuple/list, whose entries are unpacked into ``np.random.uniform``.
    Anything else raises ``RuntimeError``.
    """

    def __init__(self, rectangle_value: Union[int, float, Tuple[float, float], Callable]):
        self.rectangle_value = rectangle_value

    def __call__(self, x: torch.Tensor) -> float:
        spec = self.rectangle_value
        if np.isscalar(spec):
            value = spec
        elif callable(spec):
            value = spec(x)
        elif isinstance(spec, (tuple, list)):
            value = np.random.uniform(*spec)
        else:
            raise RuntimeError("Unrecognized format for rectangle_value")
        return float(value)
20
+
21
+
22
class BlankRectangleTransform(ImageOnlyTransform):
    """
    Overwrites random rectangles in the image with a constant or sampled value.

    Supports 2D/3D data and various configurations of rectangle size/value.

    Args:
        rectangle_size: edge length(s) of the rectangles. One of:
            * an int, used as the edge length along every spatial axis;
            * a tuple of ints, giving one fixed edge length per spatial axis;
            * a tuple of ``(low, high)`` pairs, giving one range per spatial axis. Lengths are drawn with
              ``np.random.randint``, so ``high`` is exclusive.
            Edge lengths larger than the image are clipped to the image extent.
        rectangle_value: fill value. Either a scalar, a ``(low, high)`` range sampled uniformly, or a callable that
            receives the region about to be overwritten.
        num_rectangles: rectangles per selected channel, either fixed or ``(low, high)`` with ``high`` exclusive.
        force_square: only used with per-axis ranges. One edge length is drawn from the first range and used for
            all axes.
        p_per_channel: probability that a channel receives rectangles at all.
    """

    def __init__(self,
                 rectangle_size: Union[int,
                                       Tuple[int, ...],
                                       Tuple[Tuple[int, int], ...]],
                 rectangle_value: Union[int, float, Tuple[float, float], Callable],
                 num_rectangles: Union[int, Tuple[int, int]],
                 force_square: bool = False,
                 p_per_channel: float = 1.0):
        super().__init__()
        self.rectangle_size = rectangle_size
        self.num_rectangles = num_rectangles
        self.force_square = force_square
        self.p_per_channel = p_per_channel
        self.color_fn = ColorFunctionExtractor(rectangle_value)

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """Sample which channels are modified and the (lower, upper) corners of every rectangle."""
        C = image.shape[0]
        spatial_shape = image.shape[1:]
        D = len(spatial_shape)

        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(C)]

        # Number of rectangles per channel (a (low, high) range excludes high, like np.random.randint)
        if isinstance(self.num_rectangles, (int, np.integer)):
            n_rects = [int(self.num_rectangles)] * C
        else:
            n_rects = [np.random.randint(self.num_rectangles[0], self.num_rectangles[1]) for _ in range(C)]

        # Precompute all rectangles for all channels
        rectangles = [[] for _ in range(C)]
        for c in range(C):
            if not apply_channel[c]:
                continue
            for _ in range(n_rects[c]):
                sampled = self._sample_rectangle_size(D)
                # Clip every edge to [0, image extent]. Otherwise a rectangle larger than the image would give the
                # randint call below high <= low, which raises ValueError.
                size = [min(max(int(sampled[d]), 0), spatial_shape[d]) for d in range(D)]
                lb = [np.random.randint(0, spatial_shape[d] - size[d] + 1) for d in range(D)]
                ub = [lb[d] + size[d] for d in range(D)]
                rectangles[c].append((lb, ub))

        return {
            'apply_channel': apply_channel,
            'rectangles': rectangles
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Fill the sampled rectangles. ``img`` is modified in place and returned."""
        out = img
        for c, (apply, rects) in enumerate(zip(params['apply_channel'], params['rectangles'])):
            if not apply:
                continue
            for lb, ub in rects:
                slices = tuple(slice(l, u) for l, u in zip(lb, ub))
                # The fill value may depend on the region's current content (callable rectangle_value).
                out[c][slices] = self.color_fn(out[c][slices])
        return out

    def _sample_rectangle_size(self, ndim: int) -> List[int]:
        """Return one edge length per spatial axis, following the ``rectangle_size`` specification."""
        rs = self.rectangle_size
        if isinstance(rs, (int, np.integer)):
            return [int(rs)] * ndim

        if isinstance(rs, (tuple, list)) and all(isinstance(x, (int, np.integer)) for x in rs):
            return [int(x) for x in rs]

        if isinstance(rs, (tuple, list)) and all(isinstance(x, (tuple, list)) for x in rs):
            if self.force_square:
                val = np.random.randint(*rs[0])
                return [val] * ndim
            return [np.random.randint(*rs[d]) for d in range(ndim)]

        raise RuntimeError("Unrecognized format for rectangle_size")
99
+
100
+
101
if __name__ == '__main__':
    # Visual smoke test on a random 3D volume; requires the optional `batchviewer` package.
    # Create a random 3D image (C, D, H, W)
    image = torch.rand(1, 32, 64, 64)  # Single-channel 3D volume

    # Instantiate the transform
    transform = BlankRectangleTransform(
        rectangle_size=((4, 10), (10, 20), (10, 20)),  # (Z, Y, X) size ranges, upper bounds exclusive
        rectangle_value=(0.0, 1.0),  # Random intensity per rectangle
        num_rectangles=(3, 7),  # 3 to 6 rectangles per channel
        force_square=False,
        p_per_channel=1.0  # Always apply to the channel
    )

    # Sample transform parameters and apply.
    params = transform.get_parameters(image=image)
    # _apply_to_image writes into its input, so augment a copy. Otherwise both panels would show the same tensor.
    image_aug = transform._apply_to_image(image.clone(), **params)

    from batchviewer import view_batch
    view_batch(image, image_aug)
@@ -0,0 +1,260 @@
1
+ from copy import deepcopy
2
+
3
+ import numpy as np
4
+ from time import time
5
+ import torch
6
+ from skimage.data import camera
7
+ from torch.nn.functional import pad, conv3d, conv1d, conv2d
8
+
9
+ from batchgeneratorsv2.helpers.fft_conv import fft_conv
10
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
11
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
12
+
13
+
14
def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6):
    """
    Smoothes an input image with a 1D Gaussian kernel along the specified dimension.
    The function supports 1D, 2D, and 3D images.

    :param img: Input image tensor with shape (C, X), (C, X, Y), or (C, X, Y, Z),
                where C is the channel dimension and X, Y, Z are spatial dimensions.
    :param sigma: The standard deviation of the Gaussian kernel.
    :param dim_to_blur: The dimension along which to apply the Gaussian blur (0 for X, 1 for Y, 2 for Z).
    :param force_use_fft: True uses fft_conv; None or False use the direct convolution.
    :param truncate: forwarded to the kernel builder. The kernel has about sigma * truncate taps (rounded to an odd
                     number), so it reaches roughly truncate / 2 standard deviations to each side.
    :return: The blurred image tensor, with the same shape as img. The result is a new tensor; img is not modified.
    """
    assert img.ndim - 1 > dim_to_blur, "dim_to_blur must be a valid spatial dimension of the input image."
    spatial_dims = img.ndim - 1  # Number of spatial dimensions in the input image
    # The kernel must share the image's dtype and device. The default float32 CPU kernel makes the convolution
    # raise for e.g. float64 or CUDA input.
    kernel = _build_kernel(sigma, truncate=truncate).to(device=img.device, dtype=img.dtype)
    ksize = kernel.shape[0]

    # Pick the convolution operation for the number of spatial dimensions (or the FFT-based one when forced)
    conv_ops = {1: conv1d, 2: conv2d, 3: conv3d}
    conv_op = fft_conv if force_use_fft else conv_ops[spatial_dims]

    # Shape the kernel so it only extends along dim_to_blur. F.pad lists padding starting at the LAST dimension.
    if spatial_dims == 1:
        kernel = kernel[None, None, :]
        padding = [ksize // 2, ksize // 2]
    elif spatial_dims == 2:
        if dim_to_blur == 0:
            kernel = kernel[None, None, :, None]
            padding = [0, 0, ksize // 2, ksize // 2]
        else:  # dim_to_blur == 1
            kernel = kernel[None, None, None, :]
            padding = [ksize // 2, ksize // 2, 0, 0]
    else:  # spatial_dims == 3
        if dim_to_blur == 0:
            kernel = kernel[None, None, :, None, None]
            padding = [0, 0, 0, 0, ksize // 2, ksize // 2]
        elif dim_to_blur == 1:
            kernel = kernel[None, None, None, :, None]
            padding = [0, 0, ksize // 2, ksize // 2, 0, 0]
        else:  # dim_to_blur == 2
            kernel = kernel[None, None, None, None, :]
            padding = [ksize // 2, ksize // 2, 0, 0, 0, 0]

    # 'reflect' requires the padding to be smaller than the padded axis. Fall back to 'replicate' for axes that are
    # too short (singleton axes, large sigmas on small images) instead of raising.
    mode = "reflect" if ksize // 2 < img.shape[1:][dim_to_blur] else "replicate"
    img_padded = pad(img, padding, mode=mode)

    # Apply convolution.
    # Weights are [c_out, c_in / groups, ...]. With one kernel copy per channel and groups=C, every channel is
    # blurred independently. The padding makes the 'valid' convolution return the input size.
    img_blurred = conv_op(img_padded[None], kernel.expand(img_padded.shape[0], *[-1] * (kernel.ndim - 1)), groups=img_padded.shape[0])[0]
    return img_blurred
69
+
70
+
71
class GaussianBlurTransform(ImageOnlyTransform):
    def __init__(self,
                 blur_sigma: RandomScalar = (1, 5),
                 synchronize_channels: bool = False,  # todo make this p_synchronize_channels
                 synchronize_axes: bool = False,  # todo make this p_synchronize_axes
                 p_per_channel: float = 1,
                 benchmark: bool = False
                 ):
        """
        Gaussian blur built from separable 1D Gaussian filters (one pass per spatial axis) for speed.

        :param blur_sigma: standard deviation of the Gaussian, resolved with sample_scalar. This class calls
            sample_scalar with the image shape and ``dim`` set to the spatial axis (1, 2 or 3 for x, y and z),
            or ``dim=None`` when one sigma is shared by all axes.
        :param synchronize_channels: if True, all selected channels share the same sigma(s).
        :param synchronize_axes: if True, a single sigma is used for all spatial axes.
        :param p_per_channel: probability that a channel is blurred.
        :param benchmark: if True, time direct vs FFT convolution for every (axis length, kernel size) combination
            seen and reuse the faster option afterwards.
        """
        super().__init__()
        self.blur_sigma = blur_sigma
        self.benchmark = benchmark
        self.synchronize_channels = synchronize_channels
        self.synchronize_axes = synchronize_axes
        self.p_per_channel = p_per_channel
        self.benchmark_use_fft = {}  # axis length -> kernel size -> use fft yes or no
        self.benchmark_num_runs = 9

    def get_parameters(self, **data_dict) -> dict:
        """
        Sample which channels to blur and the sigmas to use.

        Returns a dict with:
          - 'apply_to_channel': boolean tensor with one entry per channel.
          - 'sigmas': with synchronize_channels, one list holding a sigma per spatial axis, shared by all selected
            channels. Otherwise one such list per selected channel, in channel order.
        """
        shape = data_dict['image'].shape
        dims = len(shape) - 1
        apply_to_channel = torch.rand(shape[0]) < self.p_per_channel

        def sample_sigmas():
            if self.synchronize_axes:
                # one sigma shared by all spatial axes
                return [sample_scalar(self.blur_sigma, shape, dim=None)] * dims
            # one sigma per spatial axis; dim is 1-based (1, 2, 3 for x, y, z) in every branch
            return [sample_scalar(self.blur_sigma, shape, dim=i + 1) for i in range(dims)]

        if self.synchronize_channels:
            sigmas = sample_sigmas()
        else:
            sigmas = [sample_sigmas() for _ in range(int(apply_to_channel.sum()))]
        return {'apply_to_channel': apply_to_channel, 'sigmas': sigmas}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Blur the selected channels of ``img`` in place and return it."""
        apply_to_channel = params['apply_to_channel']
        # The mask has one entry per channel, so its length is never 0. What matters is whether ANY channel was
        # selected: an empty selection would otherwise reach the grouped convolution with zero groups.
        if not torch.as_tensor(apply_to_channel).any():
            return img
        dim = len(img.shape[1:])

        if self.synchronize_channels:
            # All selected channels share the sigmas, so each axis is one grouped convolution over all of them.
            for d in range(dim):
                if not self.benchmark:
                    img[apply_to_channel] = blur_dimension(img[apply_to_channel], params['sigmas'][d], d)
                else:
                    img[apply_to_channel] = self._benchmark_wrapper(img[apply_to_channel], params['sigmas'][d], d)
        else:
            # Per-channel sigmas: params['sigmas'][j] belongs to the j-th selected channel.
            idx = np.where(apply_to_channel)[0]
            for j, i in enumerate(idx):
                for d in range(dim):
                    if not self.benchmark:
                        img[i:i + 1] = blur_dimension(img[i:i + 1], params['sigmas'][j][d], d)
                    else:
                        img[i:i + 1] = self._benchmark_wrapper(img[i:i + 1], params['sigmas'][j][d], d)
        return img

    def _benchmark_wrapper(self, img: torch.Tensor, sigma: float, dim_to_blur: int):
        """
        Blur like blur_dimension, choosing between direct and FFT convolution from cached timings.

        The first time an (axis length, kernel size) combination is seen, both variants run benchmark_num_runs
        times. The one with the lower median runtime is remembered in self.benchmark_use_fft.
        """
        # NOTE(review): the key uses _compute_kernel_size's default truncate (4), while blur_dimension builds its
        # kernel with truncate=6. The key is therefore a proxy for the actual kernel size, not the size itself.
        kernel_size = _compute_kernel_size(sigma)
        shp = img.shape[dim_to_blur + 1]
        cache = self.benchmark_use_fft.setdefault(shp, {})
        if kernel_size not in cache:
            dummy_img = deepcopy(img)

            def median_runtime(use_fft):
                times = []
                for _ in range(self.benchmark_num_runs):
                    st = time()
                    blur_dimension(dummy_img, sigma, dim_to_blur, force_use_fft=use_fft)
                    times.append(time() - st)
                return np.median(times)

            time_direct = median_runtime(False)
            time_fft = median_runtime(True)
            cache[kernel_size] = time_fft < time_direct
            # convenience: keep the entries sorted by kernel size
            self.benchmark_use_fft[shp] = dict(sorted(cache.items()))
        return blur_dimension(img, sigma, dim_to_blur, force_use_fft=self.benchmark_use_fft[shp][kernel_size])
170
+
171
+
172
def _build_kernel(sigma: float, truncate: float = 4) -> torch.Tensor:
    """
    Return a normalized (sum == 1) 1D Gaussian kernel with standard deviation ``sigma``.

    The kernel has ``_compute_kernel_size(sigma, truncate=truncate)`` taps, placed symmetrically around 0.
    """
    n_taps = _compute_kernel_size(sigma, truncate=truncate)
    radius = (n_taps - 1) / 2
    positions = torch.linspace(-radius, radius, steps=n_taps)
    weights = torch.exp((positions / sigma) ** 2 * -0.5)
    return weights / weights.sum()
181
+
182
+
183
def _round_to_nearest_odd(n):
    """
    Round ``n`` to an odd integer.

    ``n`` is first rounded with Python's ``round`` (ties go to the even neighbour). An odd result is returned
    unchanged. An even result moves one step towards ``n``: up when ``n`` is not below it, down otherwise.
    """
    r = round(n)
    if r % 2 == 1:
        return r
    step = 1 if n >= r else -1
    return r + step
190
+
191
+
192
def _compute_kernel_size(sigma, truncate: float = 4):
    """
    Number of taps of the Gaussian kernel for ``sigma``: ``sigma * truncate + 0.5`` rounded to an odd integer.

    Note: ``truncate`` bounds the *total* kernel width in units of sigma. The kernel therefore reaches about
    ``truncate / 2`` standard deviations to each side, whereas scipy.ndimage treats ``truncate`` as the one-sided
    radius.
    """
    return _round_to_nearest_odd(truncate * sigma + 0.5)
195
+
196
+
197
if __name__ == "__main__":
    # Manual benchmark. Compares this GaussianBlurTransform (with and without the direct-vs-FFT benchmarking)
    # against batchgenerators' GaussianBlurTransform. Needs the optional `batchgenerators` and `batchviewer` packages.

    # this is the fastest for larger kernels but it doesn't fit well into the current benchmark scheme
    # tmp = np.fft.fftn(offsets[d].numpy())
    # tmp = fourier_gaussian(tmp, sigmas)
    # offsets[d] = torch.from_numpy(np.fft.ifftn(tmp).real)

    import os
    from batchgenerators.transforms.noise_transforms import GaussianBlurTransform as GBTBG
    from batchviewer import view_batch

    # torch is already imported at this point, so torch.set_num_threads is what actually limits its threads
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    # quick 2D smoke test on the skimage camera image
    data_dict = {'image': torch.from_numpy(camera()[None]).float()}
    gnt2 = GaussianBlurTransform(2, False, False, 1, benchmark=False)
    out = gnt2(**data_dict)
    # view_batch(out['image'], torch.from_numpy(camera()[None]).float())

    shape = (128, 164, 64)
    num_warmup_for_benchmark = 1
    num_repeats = 10

    def fill_first_half(arr, num_leading_axes):
        # Set the first half of every spatial axis to 200, keeping all entries of the leading
        # (channel, or batch + channel) axes selected.
        arr[tuple([slice(None)] * num_leading_axes + [slice(0, s // 2) for s in shape])] = 200
        return arr

    def make_torch_input():
        return {'image': fill_first_half(torch.ones((2, *shape)), 1)}

    def make_bg_input():
        # batchgenerators expects (b, c, x, y, z). Skip both the batch and the channel axis so the filled region
        # matches the torch inputs.
        return {'data': fill_first_half(np.ones((1, 2, *shape)), 2)}

    def time_runs(transform, make_input):
        times = []
        for _ in range(num_repeats):
            inp = make_input()
            st = time()
            transform(**inp)
            times.append(time() - st)
        return np.median(times)

    for sigma_range in (0.1, 1, 10, 20):
        print(shape, sigma_range)
        gnt2 = GaussianBlurTransform(sigma_range, False, False, 1, benchmark=False)
        print('w /o benchmark', time_runs(gnt2, make_torch_input))

        gnt = GaussianBlurTransform(sigma_range, False, False, 1, benchmark=True)
        # warmup: lets the transform cache its direct-vs-FFT choice before the timed runs
        for _ in range(num_warmup_for_benchmark):
            gnt(**make_torch_input())
        print('with benchmark', time_runs(gnt, make_torch_input))

        gnt3 = GBTBG(sigma_range, True, True, 0, 1, 1)
        print('batchgenerator', time_runs(gnt3, make_bg_input))
        print()
    #
    # print(gnt.benchmark_use_fft)
@@ -0,0 +1,52 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Union, Tuple
4
+ from scipy.ndimage import median_filter
5
+
6
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
7
+
8
+
9
class MedianFilterTransform(ImageOnlyTransform):
    """
    Median-filter a random subset of image channels with scipy.ndimage.median_filter.

    Attributes:
        filter_size (int or Tuple[int, int]): fixed filter size, or a (low, high) range from which sizes are drawn
            with np.random.randint (high is exclusive).
        p_same_for_each_channel (float): probability that a single drawn size is shared by all channels. Only
            relevant when filter_size is a range.
        p_per_channel (float): probability that a given channel is filtered.
    """

    def __init__(self,
                 filter_size: Union[int, Tuple[int, int]],
                 p_same_for_each_channel: float = 0.0,
                 p_per_channel: float = 1.0):
        super().__init__()
        self.filter_size = filter_size
        self.p_same_for_each_channel = p_same_for_each_channel
        self.p_per_channel = p_per_channel

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        n_channels = image.shape[0]
        share_one_size = np.random.rand() < self.p_same_for_each_channel

        if isinstance(self.filter_size, int):
            filter_sizes = [self.filter_size for _ in range(n_channels)]
        elif share_one_size:
            filter_sizes = [int(np.random.randint(*self.filter_size))] * n_channels
        else:
            filter_sizes = [int(np.random.randint(*self.filter_size)) for _ in range(n_channels)]

        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(n_channels)]
        return dict(filter_sizes=filter_sizes, apply_channel=apply_channel)

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        # For CPU tensors .numpy() returns a view, so the filtered channels also land in img's own storage.
        arr = img.cpu().numpy()
        selected = [(ch, size)
                    for ch, (flag, size) in enumerate(zip(params['apply_channel'], params['filter_sizes']))
                    if flag]
        for ch, size in selected:
            arr[ch] = median_filter(arr[ch], size=size)
        return torch.from_numpy(arr).to(img.device)
@@ -0,0 +1,61 @@
1
+ import torch
2
+ from typing import Tuple
3
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
4
+
5
+
6
class RicianNoiseTransform(ImageOnlyTransform):
    """
    Adds Rician noise to simulate MRI characteristics.

    The image is shifted so its minimum is 0. It is then combined with two independent Gaussian noise fields as the
    magnitude sqrt((x + n_real)^2 + n_imag^2), shifted back, and finally rescaled to the input's mean and standard
    deviation.

    Args:
        noise_variance (Tuple[float, float]): range from which the noise parameter is sampled uniformly.
            NOTE: the sampled value is passed to ``normal_`` as the *standard deviation* of the noise fields.
    """

    def __init__(self, noise_variance: Tuple[float, float] = (0.0, 0.1)):
        super().__init__()
        self.noise_variance = noise_variance

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        # This module only imports numpy inside its __main__ guard, so `np` is undefined here when the transform is
        # used as a library. Import it locally.
        import numpy as np
        variance = float(np.random.uniform(*self.noise_variance))
        return {'variance': variance}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Return a new noisy tensor; ``img`` itself is not modified."""
        var = params['variance']
        noise_real = torch.empty_like(img).normal_(mean=0.0, std=var)
        noise_imag = torch.empty_like(img).normal_(mean=0.0, std=var)

        # Shift so the minimum is 0 before taking the magnitude, then shift back afterwards.
        min_val = img.min()
        shifted = img - min_val

        rician = torch.sqrt((shifted + noise_real).pow_(2).add_(noise_imag.pow_(2)))
        rician = rician + min_val

        # Normalize to match original mean and std
        input_mean, input_std = img.mean(), img.std()
        rician_mean, rician_std = rician.mean(), rician.std()

        if rician_std > 0:
            rician = (rician - rician_mean) / rician_std * input_std + input_mean
        else:
            rician = rician * 0 + input_mean  # fallback if std is zero (flat image)

        return rician
43
+
44
+
45
if __name__ == '__main__':
    # Visual smoke test; requires the optional `batchviewer` package.
    # Run as a script, these imports also bind module-level names.
    import torch
    import numpy as np

    # Synthetic 3D volume (C, D, H, W): constant 0.5 with one bright outlier voxel
    volume = torch.full((1, 32, 64, 64), 0.5)
    volume[0, 1, 1, 1] = 2

    rician = RicianNoiseTransform(noise_variance=(0.05, 0.1))
    noisy = rician._apply_to_image(volume, **rician.get_parameters(image=volume))

    from batchviewer import view_batch
    view_batch(volume, noisy)
@@ -0,0 +1,128 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from typing import Union, Tuple, List
5
+
6
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
7
+
8
+
9
class SharpeningTransform(ImageOnlyTransform):
    """
    Applies sharpening to 2D or 3D images using Laplacian-based contrast enhancement.

    Each selected channel ``x`` is replaced by ``x + strength * L(x)``. Here ``L`` is the negated discrete
    Laplacian, computed with a 4-neighbour (2D) or 6-neighbour (3D) kernel on a replicate-padded copy of the channel.
    The kernel sums to zero, so constant regions are left unchanged. Optionally, the result is clamped to the
    channel's original [min, max] range.

    Attributes:
        strength (float or Tuple[float, float]): sharpening strength, either fixed or a (low, high) range sampled
            uniformly.
        p_same_for_each_channel (float): probability that all channels share the same strength and clamp decision.
        p_per_channel (float): probability of applying sharpening to a given channel.
        p_clamp_intensities (float): probability of clamping a sharpened channel to its original intensity range.
    """

    def __init__(self,
                 strength: Union[float, Tuple[float, float]] = 0.2,
                 p_same_for_each_channel: float = 0.0,
                 p_per_channel: float = 1.0,
                 p_clamp_intensities: float = 0):
        super().__init__()
        self.strength = strength
        self.p_same_for_each_channel = p_same_for_each_channel
        self.p_per_channel = p_per_channel
        self.p_clamp_intensities: float = p_clamp_intensities

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        C = image.shape[0]
        use_same = np.random.rand() < self.p_same_for_each_channel

        if use_same:
            strength = self._sample_strength()
            strengths = [strength] * C
            clamp = np.random.uniform() < self.p_clamp_intensities
            clamp_intensities = [clamp] * C
        else:
            strengths = [self._sample_strength() for _ in range(C)]
            clamp_intensities = [np.random.uniform() < self.p_clamp_intensities for _ in range(C)]

        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(C)]

        return {
            'strengths': strengths,
            'clamp_intensities': clamp_intensities,
            'apply_channel': apply_channel
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Sharpen the selected channels of ``img`` in place and return it."""
        out = img
        spatial_dims = img.dim() - 1  # 2 for (C, H, W), 3 for (C, D, H, W)

        # The kernel uses the image dtype so the convolution also works for non-float32 input.
        if spatial_dims == 2:
            kernel = torch.tensor([[0, -1, 0],
                                   [-1, 4, -1],
                                   [0, -1, 0]], dtype=img.dtype, device=img.device)
            kernel = kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, 3, 3)
            padding = (1, 1, 1, 1)  # left, right, top, bottom

        elif spatial_dims == 3:
            kernel = torch.tensor([[[0, 0, 0],
                                    [0, -1, 0],
                                    [0, 0, 0]],
                                   [[0, -1, 0],
                                    [-1, 6, -1],
                                    [0, -1, 0]],
                                   [[0, 0, 0],
                                    [0, -1, 0],
                                    [0, 0, 0]]], dtype=img.dtype, device=img.device)
            kernel = kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, 3, 3, 3)
            padding = (1, 1, 1, 1, 1, 1)  # left, right, top, bottom, front, back

        else:
            raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}. Expected 2 or 3.")

        for c, (apply, strength, clamp) in enumerate(zip(params['apply_channel'], params['strengths'], params['clamp_intensities'])):
            if not apply:
                continue

            if clamp:
                # record the range before the channel is overwritten
                mn, mx = torch.min(img[c]), torch.max(img[c])

            x = img[c].unsqueeze(0).unsqueeze(0)  # (1, 1, H, W) or (1, 1, D, H, W)
            padded = F.pad(x, padding, mode='replicate')

            if spatial_dims == 2:
                laplace = F.conv2d(padded, kernel)
            else:
                laplace = F.conv3d(padded, kernel)

            sharpened = x + strength * laplace
            # Drop exactly the two added axes. .squeeze() would also remove singleton spatial axes and break the
            # write-back.
            out[c] = sharpened[0, 0]

            if clamp:
                out[c].clamp_(mn, mx)

        return out

    def _sample_strength(self) -> float:
        # ints (and numpy scalars) are fixed strengths too; unpacking them as a (low, high) range would raise
        if isinstance(self.strength, (int, float, np.integer, np.floating)):
            return float(self.strength)
        return float(np.random.uniform(*self.strength))
106
+
107
+
108
if __name__ == '__main__':
    # Visual smoke test; requires scikit-image and the optional `batchviewer` package.
    from skimage.data import camera
    from skimage.util import img_as_float32

    # Load camera image and prepare it
    img_np = img_as_float32(camera())  # (H, W), float32, values in [0, 1]
    img_torch = torch.from_numpy(img_np).unsqueeze(0)  # (1, H, W) = (C, H, W); shares memory with img_np

    # Instantiate the transform
    transform = SharpeningTransform(
        strength=(2, 2.1),  # Sharpening strength range
        p_same_for_each_channel=1.0,  # Force same strength for all channels (only 1 channel here)
        p_per_channel=1.0,  # Always apply
        p_clamp_intensities=1
    )

    # Generate parameters and apply.
    params = transform.get_parameters(image=img_torch)
    # _apply_to_image sharpens in place, and img_torch shares memory with img_np. Work on a copy so img_np still
    # holds the original image for the side-by-side view.
    sharpened = transform._apply_to_image(img_torch.clone(), **params)
    from batchviewer import view_batch
    view_batch(img_np, sharpened)
File without changes