batchgeneratorsv2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. batchgeneratorsv2/benchmarks/__init__.py +0 -0
  2. batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  3. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  4. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  5. batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  6. batchgeneratorsv2/dataloading/__init__.py +0 -0
  7. batchgeneratorsv2/helpers/__init__.py +0 -0
  8. batchgeneratorsv2/helpers/fft_conv.py +149 -0
  9. batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
  13. batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  15. batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
  16. batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  17. batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  18. batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  19. batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  20. batchgeneratorsv2/transforms/local/__init__.py +0 -0
  21. batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  22. batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  23. batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  24. batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  25. batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  26. batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
  28. batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  29. batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  30. batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  32. batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  33. batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  34. batchgeneratorsv2/transforms/noise/rician.py +61 -0
  35. batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  36. batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  37. batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  38. batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
  39. batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  40. batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  41. batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
  42. batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  43. batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  44. batchgeneratorsv2/transforms/utils/compose.py +89 -0
  45. batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  46. batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  47. batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  48. batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
  49. batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
  50. batchgeneratorsv2/transforms/utils/random.py +46 -0
  51. batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
  52. batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  53. batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
  54. batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
  55. batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
  56. batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
  57. batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
File without changes
File without changes
@@ -0,0 +1,90 @@
1
+ from time import time
2
+
3
+ import numpy as np
4
+ import torch
5
+ from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
6
+ ContrastAugmentationTransform, GammaTransform
7
+ from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
8
+ from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
9
+ from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
10
+ from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, NumpyToTensor
11
+ from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
12
+ DownsampleSegForDSTransform2
13
+ from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
14
+ from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
15
+ ConvertSegmentationToRegionsTransform
16
+ from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
17
+ Convert2DTo3DTransform
18
+
19
if __name__ == '__main__':
    # Benchmark of the classic batchgenerators / nnU-Net augmentation pipeline: each transform is timed
    # individually on freshly generated random data and the mean runtime per transform is printed.
    # perf_counter is a monotonic, high-resolution clock and therefore better suited to timing than time().
    from time import perf_counter

    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2*np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    tr_transforms = []
    if do_dummy_2d_data_aug:
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
        do_rotation=True, angle_x=rotation_for_DA, angle_y=rotation_for_DA, angle_z=rotation_for_DA,
        p_rot_per_axis=1,  # todo experiment with this
        do_scale=True, scale=(0.7, 1.4),
        border_mode_data="constant", border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=-1, order_seg=1,
        random_crop=False,  # random cropping is part of our dataloaders
        p_el_per_sample=0, p_scale_per_sample=1, p_rot_per_sample=1,
        independent_scale_for_each_axis=False  # todo experiment with this
    ))

    if do_dummy_2d_data_aug:
        tr_transforms.append(Convert2DTo3DTransform())

    tr_transforms.append(GaussianNoiseTransform(p_per_sample=1, p_per_channel=1))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=1,
                                               p_per_channel=1))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=1))
    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=1, p_per_channel=1))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=1,
                                                        order_downsample=0, order_upsample=3, p_per_sample=1,
                                                        ignore_axes=ignore_axes))
    tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=1))
    tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1))

    tr_transforms.append(MirrorTransform((0, 1, 2)))

    tr_transforms.append(MaskTransform([0, 1, 2, 3],
                                       mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'seg', 'seg'))

    tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='seg',
                                                      output_key='seg'))

    tr_transforms.append(NumpyToTensor(['data', 'seg'], 'float'))

    # one list of measured durations per transform, in pipeline order
    compute_times = [[] for _ in tr_transforms]

    torch.set_num_threads(1)
    # `iteration` instead of `iter` so the builtin of the same name is not shadowed
    for iteration in range(50):
        print(iteration)
        data_dict = {'data': np.random.uniform(size=(1, 4, 128, 128, 128)),
                     'seg': np.round(4.5 * np.random.uniform(size=(1, 1, 128, 128, 128)) - 1, decimals=0).astype(np.int8)}
        for i, t in enumerate(tr_transforms):
            st = perf_counter()
            data_dict = t(**data_dict)
            compute_times[i].append(perf_counter() - st)

    for t, ct in zip(tr_transforms, compute_times):
        print(t.__class__.__name__, np.mean(ct))
@@ -0,0 +1,138 @@
1
+ from time import time
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
7
+ from batchgeneratorsv2.transforms.intensity.contrast import BGContrast, ContrastTransform
8
+ from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
9
+ from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
10
+ from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
11
+ from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
12
+ from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
13
+ from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
14
+ from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
15
+ from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
16
+ from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
17
+ from batchgeneratorsv2.transforms.utils.pseudo2d import Convert2DTo3DTransform, Convert3DTo2DTransform
18
+ from batchgeneratorsv2.transforms.utils.random import RandomTransform
19
+ from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
20
+ from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
21
+
22
if __name__ == '__main__':
    # Benchmark of the batchgeneratorsv2 counterpart of the nnU-Net augmentation pipeline: each transform is
    # timed individually on freshly generated random tensors and the mean runtime per transform is printed.
    # perf_counter is a monotonic, high-resolution clock and therefore better suited to timing than time().
    from time import perf_counter

    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2*np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    transforms = []
    if do_dummy_2d_data_aug:
        ignore_axes = (0,)
        transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None
    transforms.append(
        SpatialTransform(
            patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
            p_rotation=1,
            rotation=rotation_for_DA, p_scaling=1, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1
        )
    )
    if do_dummy_2d_data_aug:
        transforms.append(Convert2DTo3DTransform())

    transforms.append(
        GaussianNoiseTransform(
            noise_variance=(0, 0.1),
            p_per_channel=1,
            synchronize_channels=True
        )
    )
    transforms.append(
        GaussianBlurTransform(
            blur_sigma=(0.5, 1.),
            synchronize_channels=False,
            synchronize_axes=False,
            p_per_channel=1, benchmark=True
        ))
    transforms.append(
        MultiplicativeBrightnessTransform(
            multiplier_range=BGContrast((0.75, 1.25)),
            synchronize_channels=False,
            p_per_channel=1
        ))
    transforms.append(
        ContrastTransform(
            contrast_range=BGContrast((0.75, 1.25)),
            preserve_range=True,
            synchronize_channels=False,
            p_per_channel=1
        ))
    transforms.append(
        SimulateLowResolutionTransform(
            scale=(0.5, 1),
            synchronize_channels=False,
            synchronize_axes=True,
            ignore_axes=ignore_axes,
            allowed_channels=None,
            p_per_channel=1
        ))
    transforms.append(
        GammaTransform(
            gamma=BGContrast((0.7, 1.5)),
            p_invert_image=1,
            synchronize_channels=False,
            p_per_channel=1,
            p_retain_stats=1
        ))
    transforms.append(
        GammaTransform(
            gamma=BGContrast((0.7, 1.5)),
            p_invert_image=0,
            synchronize_channels=False,
            p_per_channel=1,
            p_retain_stats=1
        ))
    transforms.append(
        MirrorTransform(
            allowed_axes=(0, 1, 2)
        )
    )

    transforms.append(MaskImageTransform(
        apply_to_channels=[0, 1, 2, 3],
        channel_idx_in_seg=0,
        set_outside_to=0,
    ))

    transforms.append(
        RemoveLabelTansform(-1, 0)
    )

    transforms.append(
        ConvertSegmentationToRegionsTransform(
            regions=regions,
            channel_in_seg=0
        )
    )

    transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))

    # one list of measured durations per transform, in pipeline order
    compute_times = [[] for _ in transforms]

    with torch.no_grad():
        torch.set_num_threads(1)
        # `iteration` instead of `iter` so the builtin of the same name is not shadowed
        for iteration in range(50):
            print(iteration)
            data_dict = {'image': torch.rand((4, 128, 128, 128)),
                         'segmentation': torch.round(4.5 * torch.rand((1, 128, 128, 128)) - 1, decimals=0).to(torch.int8)}
            for i, t in enumerate(transforms):
                st = perf_counter()
                data_dict = t(**data_dict)
                compute_times[i].append(perf_counter() - st)

    # RandomTransform wrappers are reported under the name of the transform they wrap
    for t, ct in zip(transforms, compute_times):
        print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__, np.mean(ct))
@@ -0,0 +1,55 @@
1
+ import torch
2
+ import numpy as np
3
+ from time import time
4
+ import pandas as pd
5
+
6
def unique_torch(tensor):
    """Return the unique values of ``tensor`` as computed by ``torch.unique``."""
    distinct_values = torch.unique(tensor)
    return distinct_values
8
+
9
def unique_npy(tensor):
    """Convert ``tensor`` to a numpy array and return its unique values via ``np.unique``."""
    as_array = tensor.numpy()
    return np.unique(as_array)
11
+
12
def unique_pandas(tensor):
    """Return the sorted unique values of ``tensor`` computed with ``pandas.unique``.

    The tensor is converted to a numpy array and flattened before being passed to
    ``pd.unique``; the result is sorted with ``np.sort`` and returned.
    """
    # the original computed this value but never returned it, so callers always got None
    return np.sort(pd.unique(tensor.numpy().ravel()))
14
+
15
def unique_bincount(tensor):
    """Return the values occurring in ``tensor``, found via ``torch.bincount`` on the flattened input."""
    occurrences = torch.bincount(tensor.ravel())
    is_present = occurrences > 0
    # torch.where with a single argument yields one index tensor per dimension; the counts are 1D
    (present_values,) = torch.where(is_present)
    return present_values
17
+
18
+
19
if __name__ == '__main__':
    # Micro-benchmark: times each unique-value implementation on random uint8 label maps and prints the
    # median runtime per implementation.
    # perf_counter is a monotonic, high-resolution clock and therefore better suited to timing than time().
    from time import perf_counter

    torch.set_num_threads(1)
    shape = (64, 64, 64)
    max_label = 20  # random labels are drawn from 0..max_label
    num_repeats = 10

    # one shared loop instead of four copy-pasted ones; the printed label is the function's own name
    for unique_fn in (unique_torch, unique_npy, unique_pandas, unique_bincount):
        times = []
        for _ in range(num_repeats):
            seg = torch.round(torch.rand(shape) * max_label, decimals=0).to(torch.uint8)
            st = perf_counter()
            unique_fn(seg)
            times.append(perf_counter() - st)
        print(unique_fn.__name__, np.median(times))
55
+
File without changes
File without changes
@@ -0,0 +1,149 @@
1
+ from typing import Iterable, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as f
5
+ from torch import Tensor, nn
6
+ from torch.fft import irfftn, rfftn
7
+ from math import ceil, floor
8
+
9
+
10
+ # Taken from here: https://github.com/vcasellesb/fft-conv-pytorch/tree/non-tup-slice-fix. THANK YOU!
11
+ # Original codebase is here: https://github.com/fkodom/fft-conv-pytorch -> unfortunately was not updated for pytorch 2.9.
12
+
13
+
14
def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Multiply two complex-valued tensors over their channel dimension.

    The product is a matrix multiplication over the channel axes only; all trailing
    dimensions are carried through unchanged. Channels can be split into ``groups``
    that are multiplied independently of one another (needed for group convolutions).
    """
    # split the channel axis into (groups, channels_per_group)
    lhs = a.view(a.size(0), groups, -1, *a.shape[2:])
    rhs = b.view(groups, -1, *b.shape[1:])

    # move the channel axes to the end so that ``@`` contracts over them
    lhs = torch.movedim(lhs, 2, -1).unsqueeze(-2)
    rhs = torch.movedim(rhs, (1, 2), (-1, -2))

    # complex product written out on the real and imaginary parts
    lhs_re, lhs_im = lhs.real, lhs.imag
    rhs_re, rhs_im = rhs.real, rhs.imag
    prod_re = lhs_re @ rhs_re - lhs_im @ rhs_im
    prod_im = lhs_im @ rhs_re + lhs_re @ rhs_im

    # bring the output-channel axis back to position 2 and drop the singleton matmul axis
    prod_re = torch.movedim(prod_re, -1, 2).squeeze(-1)
    prod_im = torch.movedim(prod_im, -1, 2).squeeze(-1)

    result = torch.zeros(prod_re.shape, dtype=torch.complex64, device=lhs.device)
    result.real = prod_re
    result.imag = prod_im

    # merge (groups, channels_per_group) back into a single channel axis
    return result.view(result.size(0), -1, *result.shape[3:])
35
+
36
+
37
def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
    """Expand ``val`` into a tuple of length ``n``.

    Handy for convolution arguments such as padding or stride, which callers may give
    either as one number or as one value per dimension.

    Args:
        val: (Union[int, Iterable[int]]) A single value or an iterable of values.
        n: (int) Required length of the result.

    Returns:
        (Tuple[int, ...]) ``val`` repeated ``n`` times if it is not iterable, otherwise
        its items as a tuple.

    Raises:
        ValueError: If ``val`` is iterable but does not contain exactly ``n`` items.
    """
    if not isinstance(val, Iterable):
        return n * (val,)

    items = tuple(val)
    if len(items) != n:
        raise ValueError(f"Cannot cast tuple of length {len(items)} to length {n}.")
    return items
56
+
57
+
58
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    padding_mode: str = "constant",
    stride: Union[int, Iterable[int]] = 1,
    dilation: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order to mimic the PyTorch direct convolution).

    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
        bias: (Tensor) Bias tensor to add to the output.
        padding: (Union[int, Iterable[int], str]) If int, number of samples to pad the
            input with on each side of every spatial dimension; if iterable, one such
            number per spatial dimension. The only supported string is "same", which
            pads the input so that the output keeps the input size.
        padding_mode: (str) Forwarded as ``mode`` to ``torch.nn.functional.pad``.
            Which non-constant modes are available depends on the input
            dimensionality and the installed PyTorch version.
        stride: (Union[int, Iterable[int]]) Stride size for computing output values.
        dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
        groups: (int) Number of groups for the convolution.

    Returns:
        (Tensor) Convolved tensor

    Raises:
        ValueError: If ``padding`` is a string other than "same", or if "same" is
            combined with a stride or dilation different from 1.
    """
    # Reject unsupported padding strings up front. Previously they fell through and
    # crashed further down because the padding was never assigned.
    if isinstance(padding, str) and padding != "same":
        raise ValueError(f"Padding mode {padding} not supported.")

    # Cast padding, stride & dilation to tuples.
    n = signal.ndim - 2
    stride_ = to_ntuple(stride, n=n)
    dilation_ = to_ntuple(dilation, n=n)
    if isinstance(padding, str):
        # compare the normalized tuples so that e.g. stride=(1, 1) is accepted just like stride=1
        if any(s != 1 for s in stride_) or any(d != 1 for d in dilation_):
            raise ValueError("stride and dilation must be 1 for padding='same'.")
        # may be fractional for even kernel sizes; split into floor/ceil below
        padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
    else:
        padding_ = to_ntuple(padding, n=n)

    # internal dilation offsets
    offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
    offset[(slice(None), slice(None), *((0,) * n))] = 1.0

    # correct the kernel by cutting off unwanted dilation trailing zeros
    cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)

    # pad the kernel internally according to the dilation parameters
    kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

    # Pad the input signal & kernel tensors (round to support even sized convolutions)
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # original signal size without padding to even
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim)))

    # conjugate the kernel spectrum before multiplying
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    )
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
@@ -0,0 +1,28 @@
1
+ from typing import Union, Tuple, Callable
2
+ import numpy as np
3
+
4
+
5
# A scalar specification: a fixed number, a (low, high) pair to sample uniformly from, or a
# callable that produces the number.
RandomScalar = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]]


def sample_scalar(scalar_type: RandomScalar, *args, **kwargs):
    """Resolve a scalar specification into a concrete value.

    Args:
        scalar_type: One of
            - an int or float, which is returned unchanged;
            - a list or tuple ``(low, high)`` with ``low <= high``; a value is drawn with
              ``np.random.uniform(low, high)`` unless both entries are equal, in which
              case that entry is returned without sampling;
            - a callable, which is invoked as ``scalar_type(*args, **kwargs)``.
        *args: Positional arguments forwarded to a callable ``scalar_type``.
        **kwargs: Keyword arguments forwarded to a callable ``scalar_type``.

    Returns:
        The fixed, sampled or computed scalar.

    Raises:
        AssertionError: If a list/tuple does not have length 2 or is not ordered.
        RuntimeError: If ``scalar_type`` is none of the supported kinds.
    """
    if isinstance(scalar_type, (int, float)):
        return scalar_type
    elif isinstance(scalar_type, (list, tuple)):
        assert len(scalar_type) == 2, 'if list is provided, its length must be 2'
        assert scalar_type[0] <= scalar_type[1], 'if list is provided, first entry must be smaller or equal than second entry, ' \
                                                 'otherwise we cannot sample using np.random.uniform'
        if scalar_type[0] == scalar_type[1]:
            return scalar_type[0]
        return np.random.uniform(*scalar_type)
    elif callable(scalar_type):
        return scalar_type(*args, **kwargs)
    else:
        # Exceptions do not apply %-formatting to their arguments (unlike logging calls), so the
        # message has to be formatted before it is handed to RuntimeError.
        raise RuntimeError('Unknown type: %s. Expected: int, float, list, tuple, callable' % type(scalar_type))
22
+
23
+
24
if __name__ == '__main__':
    # Smoke test: call sample_scalar once per supported kind of specification (fixed value,
    # range, callable without arguments, callable with arguments). The results are discarded.
    demo_calls = (
        (0.5, ()),
        ((0, 1), ()),
        (lambda: np.random.uniform(-1, 2), ()),
        (lambda x, y: np.random.uniform(x, y), (0.5, 2)),
    )
    for spec, extra_args in demo_calls:
        sample_scalar(spec, *extra_args)
File without changes
File without changes
@@ -0,0 +1,77 @@
1
+ import abc
2
+ import torch
3
+
4
+
5
class BasicTransform(abc.ABC):
    """
    Base class for all transforms.

    Transforms are applied to each sample individually. The dataloader is responsible for collating, or we might
    consider a CollateTransform

    We expect (C, X, Y) or (C, X, Y, Z) shaped inputs for image and seg (yes seg can have more color channels)

    The layout of keypoints and bounding boxes is not defined here.

    Subclasses override ``get_parameters`` to draw the parameters for one call and the ``_apply_to_*`` hooks for the
    entries they modify. Hooks that are not overridden return their input unchanged, so entries a transform does not
    care about pass through untouched.
    """
    def __init__(self):
        pass

    def __call__(self, **data_dict) -> dict:
        """Draw parameters for this sample via ``get_parameters`` and apply the transform with them."""
        params = self.get_parameters(**data_dict)
        return self.apply(data_dict, **params)

    def apply(self, data_dict, **params):
        """Run the matching hook on every known entry that is present (not None) and return ``data_dict``.

        Entries are processed in the order image, regression_target, segmentation, keypoints, bbox. The dict is
        modified in place.
        """
        handlers = (
            ('image', self._apply_to_image),
            ('regression_target', self._apply_to_regr_target),
            ('segmentation', self._apply_to_segmentation),
            ('keypoints', self._apply_to_keypoints),
            ('bbox', self._apply_to_bbox),
        )
        for key, handler in handlers:
            if data_dict.get(key) is not None:
                data_dict[key] = handler(data_dict[key], **params)

        return data_dict

    # The default hooks are identity functions. They used to return None, which made ``apply`` overwrite any
    # present entry whose hook a subclass had not overridden with None.
    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return img

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        return regression_target

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        return segmentation

    def _apply_to_keypoints(self, keypoints, **params):
        return keypoints

    def _apply_to_bbox(self, bbox, **params):
        return bbox

    def get_parameters(self, **data_dict) -> dict:
        """Return the keyword arguments passed on to ``apply``; the base implementation has none."""
        return {}

    def __repr__(self):
        ret_str = str(type(self).__name__) + "( " + ", ".join(
            [key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )"
        return ret_str
60
+
61
+
62
class ImageOnlyTransform(BasicTransform):
    """Transform that only touches the 'image' entry; every other entry is left as it is."""

    def apply(self, data_dict: dict, **params) -> dict:
        image = data_dict.get('image')
        if image is None:
            # no image present, nothing to do
            return data_dict
        data_dict['image'] = self._apply_to_image(image, **params)
        return data_dict
67
+
68
+
69
class SegOnlyTransform(BasicTransform):
    """Transform that only touches the 'segmentation' entry; every other entry is left as it is."""

    def apply(self, data_dict: dict, **params) -> dict:
        segmentation = data_dict.get('segmentation')
        if segmentation is None:
            # no segmentation present, nothing to do
            return data_dict
        data_dict['segmentation'] = self._apply_to_segmentation(segmentation, **params)
        return data_dict
74
+
75
+
76
if __name__ == '__main__':
    # nothing to run: this module only defines the transform base classes
    ...
File without changes
@@ -0,0 +1,123 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class MultiplicativeBrightnessTransform(ImageOnlyTransform):
    """
    Multiplies image channels by a factor obtained from ``multiplier_range`` via ``sample_scalar``.

    Args:
        multiplier_range (RandomScalar): Specification of the multiplier, resolved with ``sample_scalar``.
        synchronize_channels (bool): If True, one multiplier is sampled and shared by all selected channels;
            otherwise one multiplier is sampled per selected channel.
        p_per_channel (float): Probability with which each channel is selected.
    """

    def __init__(self, multiplier_range: RandomScalar, synchronize_channels: bool, p_per_channel: float = 1):
        super().__init__()
        self.multiplier_range = multiplier_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict['image']
        num_channels = image.shape[0]
        # indices of the channels that were randomly selected for this call
        selected = torch.where(torch.rand(num_channels) < self.p_per_channel)[0]
        if self.synchronize_channels:
            # a single draw, repeated for every selected channel
            shared = sample_scalar(self.multiplier_range, image=image, channel=None)
            factors = [shared] * len(selected)
        else:
            # an independent draw per selected channel
            factors = [sample_scalar(self.multiplier_range, image=image, channel=c) for c in selected]
        return {
            'apply_to_channel': selected,
            'multipliers': torch.Tensor(factors)
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        channels = params['apply_to_channel']
        if len(channels) == 0:
            return img
        # Deliberately a per-channel loop: the indexed in-place variant
        # img[channels] *= multipliers.view(-1, *[1]*(img.ndim - 1)) was found to be a lot slower.
        for channel, factor in zip(channels, params['multipliers']):
            img[channel] *= factor
        return img
35
+
36
+
37
class BrightnessAdditiveTransform(ImageOnlyTransform):
    """
    Adds random additive brightness noise sampled from a Gaussian distribution (mu, sigma).

    Supports either synchronized brightness shift across all channels or per-channel brightness shift.

    Args:
        mu (float): Mean of the Gaussian used to sample brightness shifts.
        sigma (float): Standard deviation of the Gaussian.
        synchronize_channels (bool): If True, brightness shifts are shared across all channels.
        p_per_channel (float): Probability to apply the brightness shift to each channel.
    """

    def __init__(self,
                 mu: float,
                 sigma: float,
                 synchronize_channels: bool = True,
                 p_per_channel: float = 1.0):
        super().__init__()
        self.mu = mu
        self.sigma = sigma
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        """Select the channels to modify and draw the brightness shift(s) for this call.

        Returns a dict with 'apply_to_channel' (indices of the selected channels) and 'shift' (one entry per
        image channel, or None if no channel was selected).
        """
        img = data_dict["image"]
        c = img.shape[0]
        apply_to_channel = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()

        if len(apply_to_channel) == 0:
            return {"apply_to_channel": apply_to_channel, "shift": None}

        mu, sigma = float(self.mu), float(self.sigma)
        if self.synchronize_channels:
            # One Gaussian draw shared by all channels. The previous implementation passed (mu, sigma) to
            # sample_scalar, which treats the pair as a uniform range [mu, sigma] and asserts mu <= sigma,
            # contradicting the documented Gaussian behaviour.
            shift = torch.empty(1, device=img.device).normal_(mu, sigma).repeat(c)
        else:
            # independent Gaussian draw per channel
            shift = torch.empty(c, device=img.device).normal_(mu, sigma)

        return {"apply_to_channel": apply_to_channel, "shift": shift}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Add the sampled shift in place to the selected channels of ``img`` and return it."""
        if params["shift"] is None:
            return img

        apply_idx = params["apply_to_channel"]
        if apply_idx.numel() == 0:
            return img

        shift = params["shift"]
        # Build full per-channel shift vector; non-selected channels get shift 0
        shift_full = torch.zeros((img.shape[0],), device=img.device, dtype=img.dtype)
        shift_full[apply_idx] = shift[apply_idx]

        # reshape to (C, 1, ..., 1) so the shift broadcasts over all spatial dimensions
        view_shape = (img.shape[0],) + (1,) * (img.ndim - 1)
        img.add_(shift_full.view(view_shape))
        return img
94
+
95
+
96
if __name__ == '__main__':
    # Benchmark: mean runtime of MultiplicativeBrightnessTransform versus the batchgenerators
    # BrightnessMultiplicativeTransform on all-ones inputs.
    # perf_counter is a monotonic, high-resolution clock and therefore better suited to timing than time().
    from time import perf_counter
    import os

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    # mbt = BrightnessAdditiveTransform(0, 0.5,True, 1)
    mbt = MultiplicativeBrightnessTransform((0.5, 2), False, 1)

    times_torch = []
    for _ in range(1000):
        data_dict = {'image': torch.ones((2, 128, 192, 64))}
        st = perf_counter()
        out = mbt(**data_dict)
        times_torch.append(perf_counter() - st)
    print('torch', np.mean(times_torch))

    from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform

    gnt_bg = BrightnessMultiplicativeTransform((0.5, 2), True, p_per_sample=1)
    times_bg = []
    for _ in range(1000):
        data_dict = {'data': np.ones((1, 2, 128, 192, 64))}
        st = perf_counter()
        out = gnt_bg(**data_dict)
        times_bg.append(perf_counter() - st)
    print('bg', np.mean(times_bg))