batchgeneratorsv2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57):
  1. batchgeneratorsv2/benchmarks/__init__.py +0 -0
  2. batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  3. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  4. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  5. batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  6. batchgeneratorsv2/dataloading/__init__.py +0 -0
  7. batchgeneratorsv2/helpers/__init__.py +0 -0
  8. batchgeneratorsv2/helpers/fft_conv.py +149 -0
  9. batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
  13. batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  15. batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
  16. batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  17. batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  18. batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  19. batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  20. batchgeneratorsv2/transforms/local/__init__.py +0 -0
  21. batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  22. batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  23. batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  24. batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  25. batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  26. batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
  28. batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  29. batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  30. batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  32. batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  33. batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  34. batchgeneratorsv2/transforms/noise/rician.py +61 -0
  35. batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  36. batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  37. batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  38. batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
  39. batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  40. batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  41. batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
  42. batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  43. batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  44. batchgeneratorsv2/transforms/utils/compose.py +89 -0
  45. batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  46. batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  47. batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  48. batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
  49. batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
  50. batchgeneratorsv2/transforms/utils/random.py +46 -0
  51. batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
  52. batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  53. batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
  54. batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
  55. batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
  56. batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
  57. batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,123 @@
1
+ import torch
2
+
3
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
4
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
5
+ import numpy as np
6
+
7
+
8
class BGContrast:
    """Callable sampler that draws a contrast factor from ``contrast_range``.

    ``contrast_range`` is either a zero-argument callable (its return value is
    used directly) or an indexable ``(low, high)`` pair. For a pair, a coin flip
    decides whether the factor is drawn below 1 (``[low, 1]``, only possible
    when ``low < 1``) or from ``[max(low, 1), high]``.
    """

    def __init__(self, contrast_range):
        self.contrast_range = contrast_range

    def sample_contrast(self, *args, **kwargs):
        """Return one sampled contrast factor; positional/keyword args are ignored."""
        spec = self.contrast_range
        if callable(spec):
            return spec()

        # The coin is always tossed first so the RNG stream is consumed in a
        # fixed order regardless of which branch is taken afterwards.
        coin = np.random.random()
        if coin < 0.5 and spec[0] < 1:
            return np.random.uniform(spec[0], 1)
        return np.random.uniform(max(spec[0], 1), spec[1])

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`sample_contrast` so the instance can be used as a callable."""
        return self.sample_contrast(*args, **kwargs)

    def __repr__(self):
        return f"{self.__class__.__name__}(contrast_range={self.contrast_range})"
27
+
28
+
29
+ import torch
30
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
31
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
32
+
33
+
34
class ContrastTransform(ImageOnlyTransform):
    """Scale each selected channel's deviation from its own mean by a sampled factor.

    Args:
        contrast_range: Specification handed to ``sample_scalar`` to obtain the
            multiplier(s).
        preserve_range: If True, each modified channel is clamped back to the
            min/max it had before the modification.
        synchronize_channels: If True, one multiplier is sampled and shared by
            all selected channels; otherwise one is sampled per selected channel.
        p_per_channel: Probability with which each channel is selected.
    """

    def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1.0):
        super().__init__()
        self.contrast_range = contrast_range
        self.preserve_range = preserve_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = float(p_per_channel)

    def get_parameters(self, **data_dict) -> dict:
        """Pick the channels to modify and sample their multipliers.

        Returns a dict with ``apply_to_channel`` (1D index tensor) and
        ``multipliers`` (1D tensor aligned with the indices, or None when no
        channel was selected).
        """
        image = data_dict["image"]
        num_channels = image.shape[0]

        # Random draw happens on the image's device.
        selected = torch.rand(num_channels, device=image.device) < self.p_per_channel
        apply_idx = selected.nonzero(as_tuple=False).flatten()
        count = apply_idx.numel()

        if count == 0:
            return {"apply_to_channel": apply_idx, "multipliers": None}

        if self.synchronize_channels:
            shared = float(sample_scalar(self.contrast_range, image=image, channel=None))
            multipliers = torch.full((count,), shared, device=image.device, dtype=image.dtype)
        else:
            # sample_scalar is called once per selected channel, so a Python loop remains.
            factors = [
                sample_scalar(self.contrast_range, image=image, channel=int(ch))
                for ch in apply_idx.tolist()
            ]
            multipliers = torch.as_tensor(factors, device=image.device, dtype=image.dtype)

        return {"apply_to_channel": apply_idx, "multipliers": multipliers}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Apply the sampled multipliers in place and return ``img``."""
        channels = params["apply_to_channel"]
        multipliers = params["multipliers"]
        if multipliers is None or channels.numel() == 0:
            return img

        for pos, ch in enumerate(channels.tolist()):
            plane = img[ch]  # view into img, so in-place ops write through
            factor = multipliers[pos]

            center = plane.mean()
            if self.preserve_range:
                # Remember the bounds before touching the data.
                lowest = plane.min()
                highest = plane.max()

            plane.sub_(center).mul_(factor).add_(center)

            if self.preserve_range:
                plane.clamp_(lowest, highest)

        return img
95
+
96
+
97
if __name__ == '__main__':
    # Small timing comparison between this torch implementation and the
    # batchgenerators (numpy) counterpart; both run single-threaded.
    from time import time
    import os

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    torch_transform = ContrastTransform(BGContrast((0.75, 1.25)).sample_contrast, True, False, p_per_channel=1)

    torch_durations = []
    for _ in range(100):
        sample = {'image': torch.ones((2, 128, 192, 64))}
        started = time()
        out = torch_transform(**sample)
        torch_durations.append(time() - started)
    print('torch', np.mean(torch_durations))

    from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform

    bg_transform = ContrastAugmentationTransform((0.75, 1.25), preserve_range=True, per_channel=True, p_per_channel=1)
    bg_durations = []
    for _ in range(100):
        sample = {'data': np.ones((1, 2, 128, 192, 64))}
        started = time()
        out = bg_transform(**sample)
        bg_durations.append(time() - started)
    print('bg', np.mean(bg_durations))
@@ -0,0 +1,135 @@
1
+ from typing import Optional
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class GammaTransform(ImageOnlyTransform):
    """Per-channel gamma correction, optionally on the negated image.

    For every selected channel the intensities are mapped to [0, 1] using the
    channel's min and range, raised to the sampled exponent and mapped back.
    Optionally the channel's original mean/std are restored afterwards.

    Args:
        gamma: Specification handed to ``sample_scalar`` to obtain the exponent(s).
        p_invert_image: Per selected channel, probability of negating the
            channel before the gamma step and negating it again afterwards.
        synchronize_channels: If True, one exponent is shared by all selected channels.
        p_per_channel: Probability with which each channel is selected.
        p_retain_stats: Per selected channel, probability of restoring the
            pre-gamma mean and standard deviation.
    """

    def __init__(self,
                 gamma: RandomScalar,
                 p_invert_image: float,
                 synchronize_channels: bool,
                 p_per_channel: float,
                 p_retain_stats: float):
        super().__init__()
        self.gamma = gamma
        self.p_invert_image = float(p_invert_image)
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = float(p_per_channel)
        self.p_retain_stats = float(p_retain_stats)

    def get_parameters(self, **data_dict) -> dict:
        """Select channels and sample the retain/invert flags and exponents for them.

        All tensor-valued entries are aligned with ``apply_to_channel``; they
        are None when no channel was selected.
        """
        image: torch.Tensor = data_dict["image"]
        num_channels = image.shape[0]
        device, dtype = image.device, image.dtype

        selected = torch.rand(num_channels, device=device) < self.p_per_channel
        apply_idx = selected.nonzero(as_tuple=False).flatten()
        count = apply_idx.numel()
        if count == 0:
            return {
                "apply_to_channel": apply_idx,
                "retain_stats": None,
                "invert_image": None,
                "gamma": None,
            }

        retain_stats = torch.rand(count, device=device) < self.p_retain_stats
        invert_image = torch.rand(count, device=device) < self.p_invert_image

        if self.synchronize_channels:
            shared = float(sample_scalar(self.gamma, image=image, channel=None))
            exponents = torch.full((count,), shared, device=device, dtype=dtype)
        else:
            # sample_scalar yields one value per call, hence the Python loop.
            drawn = [
                float(sample_scalar(self.gamma, image=image, channel=int(ch)))
                for ch in apply_idx.tolist()
            ]
            exponents = torch.as_tensor(drawn, device=device, dtype=dtype)

        return {
            "apply_to_channel": apply_idx,
            "retain_stats": retain_stats,
            "invert_image": invert_image,
            "gamma": exponents,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Apply the gamma correction in place to the selected channels and return ``img``."""
        channels: torch.Tensor = params["apply_to_channel"]
        if channels.numel() == 0:
            return img

        retain_stats: torch.Tensor = params["retain_stats"]
        invert_image: torch.Tensor = params["invert_image"]
        exponents: torch.Tensor = params["gamma"]

        eps = 1e-7  # lower bound for denominators

        per_channel = zip(channels.tolist(), retain_stats.tolist(), invert_image.tolist(), exponents)
        for ch, keep_stats, flip, exponent in per_channel:
            plane = img[ch]  # view into img

            if flip:
                plane.mul_(-1)

            if keep_stats:
                ref_mean = plane.mean()
                ref_std = plane.std()

            lowest = plane.min()
            span = plane.max() - lowest
            safe_span = torch.clamp(span, min=eps)

            # plane = (((plane - min) / safe_span) ** exponent) * span + min
            plane.sub_(lowest).div_(safe_span).pow_(exponent).mul_(span).add_(lowest)

            if keep_stats:
                cur_mean = plane.mean()
                cur_std = plane.std()
                plane.sub_(cur_mean)
                plane.mul_(ref_std / torch.clamp(cur_std, min=eps))
                plane.add_(ref_mean)

            if flip:
                plane.mul_(-1)

        return img
105
+
106
+
107
+
108
if __name__ == '__main__':
    # Small timing comparison between this torch implementation and the
    # batchgenerators (numpy) counterpart; both run single-threaded.
    from time import time
    import numpy as np
    import os

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    torch_transform = GammaTransform((0.7, 1.5), 0, False, 1, 1)

    torch_durations = []
    for _ in range(100):
        sample = {'image': torch.ones((2, 128, 192, 64))}
        started = time()
        out = torch_transform(**sample)
        torch_durations.append(time() - started)
    print('torch', np.mean(torch_durations))

    from batchgenerators.transforms.color_transforms import GammaTransform as BGGamma

    bg_transform = BGGamma((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1)
    bg_durations = []
    for _ in range(100):
        sample = {'data': np.ones((1, 2, 128, 192, 64))}
        started = time()
        out = bg_transform(**sample)
        bg_durations.append(time() - started)
    print('bg', np.mean(bg_durations))
@@ -0,0 +1,104 @@
1
+ from typing import Tuple
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class GaussianNoiseTransform(ImageOnlyTransform):
    """Add zero-mean Gaussian noise to randomly selected channels.

    Args:
        noise_variance: Specification handed to ``sample_scalar``.
            NOTE: despite the name, the sampled value is used as the standard
            deviation of the noise (it multiplies unit-normal noise / is passed
            as ``std``).
        p_per_channel: Probability with which each channel receives noise.
        synchronize_channels: If True, one sampled value is shared by all
            selected channels; the noise itself is still drawn independently.
    """

    def __init__(self,
                 noise_variance: RandomScalar = (0, 0.1),
                 p_per_channel: float = 1.,
                 synchronize_channels: bool = False):
        super().__init__()
        self.noise_variance = noise_variance
        self.p_per_channel = p_per_channel
        self.synchronize_channels = synchronize_channels

    def get_parameters(self, **data_dict) -> dict:
        """Select channels and sample the noise scale(s).

        ``sigmas`` is None when no channel was selected, a single sampled value
        when channels are synchronized, and a list with one value per selected
        channel otherwise.
        """
        image = data_dict["image"]
        num_channels = image.shape[0]

        # Boolean mask lives on the image's device; indices and count are
        # stored too so _apply_to_image does not have to recompute them.
        mask = torch.rand(num_channels, device=image.device) < self.p_per_channel
        selected = mask.nonzero(as_tuple=False).flatten()
        count = selected.numel()

        if count == 0:
            sigmas = None
        elif self.synchronize_channels:
            sigmas = sample_scalar(self.noise_variance, image)
        else:
            sigmas = [sample_scalar(self.noise_variance, image) for _ in range(count)]

        return {"apply_mask": mask, "apply_idx": selected, "num_apply": count, "sigmas": sigmas}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Add the noise in place to the selected channels and return ``img``."""
        count = params["num_apply"]
        if count == 0:
            return img

        target = params["apply_idx"]
        sigmas = params["sigmas"]
        if sigmas is None:
            return img

        spatial = img.shape[1:]
        # Noise is generated only for the selected channels.
        noise = torch.empty((count, *spatial), device=img.device, dtype=img.dtype)

        if self.synchronize_channels:
            noise.normal_(mean=0.0, std=float(sigmas))
        else:
            # Shape (count, 1, 1, ...) so each channel's scale broadcasts over
            # its spatial dimensions.
            scales = torch.as_tensor(sigmas, device=img.device, dtype=img.dtype)
            scales = scales.view((count,) + (1,) * len(spatial))
            noise.normal_()
            noise.mul_(scales)

        # img[target] alone would be a copy (advanced indexing); the augmented
        # assignment writes the result back into img.
        img[target] += noise
        return img
73
+
74
+
75
if __name__ == "__main__":
    # Small timing comparison between this torch implementation and the
    # batchgenerators (numpy) counterpart; both run single-threaded.
    import os  # required for os.environ below; was missing and raised NameError
    from time import time
    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    gnt = GaussianNoiseTransform((0, 0.1), 1, False)

    times = []
    for _ in range(1000):
        data_dict = {'image': torch.ones((2, 32, 32, 32))}
        st = time()
        out = gnt(**data_dict)
        times.append(time() - st)
    print('torch', np.mean(times))

    # Imported under an alias so the local GaussianNoiseTransform is not shadowed.
    from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform as BGGaussianNoiseTransform

    gnt_bg = BGGaussianNoiseTransform((0, 0.1), 1, 1, True)

    times = []
    for _ in range(1000):
        data_dict = {'data': np.ones((1, 2, 32, 32, 32))}
        st = time()
        out = gnt_bg(**data_dict)
        times.append(time() - st)

    print('bg', np.mean(times))
    # torch is 2.5x faster
@@ -0,0 +1,51 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
5
+
6
+
7
class InvertImageTransform(ImageOnlyTransform):
    """Mirror channel intensities around the channel mean.

    Args:
        p_invert_image: Probability that the transform is applied at all.
        p_synchronize_channels: Given that it is applied, probability that all
            channels are inverted together.
        p_per_channel: When channels are not synchronized, probability with
            which each channel is inverted.
    """

    def __init__(self, p_invert_image: float, p_synchronize_channels: float = 1, p_per_channel: float = 1):
        super().__init__()
        self.p_invert_image = p_invert_image
        self.p_synchronize_channels = p_synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        """Decide whether to invert and, if so, which channels.

        ``apply_to_channel`` is an index tensor when ``apply`` is true and an
        empty list otherwise.
        """
        shape = data_dict['image'].shape
        apply = np.random.uniform() < self.p_invert_image

        if not apply:
            channels = []
        elif np.random.uniform() < self.p_synchronize_channels:
            channels = torch.arange(0, shape[0])
        else:
            channels = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]

        return {'apply_to_channel': channels, 'apply': apply}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Invert the chosen channels in place around their mean and return ``img``."""
        if params['apply']:
            for ch in params['apply_to_channel']:
                plane = img[ch]  # view into img
                pivot = plane.mean()
                # x -> -(x - pivot) + pivot
                plane.sub_(pivot).mul_(-1).add_(pivot)
        return img
39
+
40
+
41
if __name__ == '__main__':
    # Visual smoke test: print one voxel per channel and show the result in batchviewer.
    transform = InvertImageTransform(0.5, 0.5, 0.5)
    from batchviewer import view_batch

    for _ in range(100):
        image = torch.ones((2, 20, 192, 64))
        image[0, :10] = -1
        image[1, :5] = -1
        sample = {'image': image}

        result = transform(**sample)
        print(result['image'][0, 0, 0, 0], result['image'][1, 0, 0, 0])
        view_batch(transform(**sample)['image'])
@@ -0,0 +1,101 @@
1
+ import torch
2
+ import numpy as np
3
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+
6
+
7
class CutOffOutliersTransform(ImageOnlyTransform):
    """
    Clamps each selected channel to a lower/upper percentile of its own
    intensities and optionally rescales the result so that the channel keeps
    the standard deviation it had before clipping.

    Args:
        percentile_lower (RandomScalar): Lower cutoff percentile (0-100).
        percentile_upper (RandomScalar): Upper cutoff percentile (0-100).
        p_synchronize_channels (bool): If truthy, the same percentile pair is used for all channels.
        p_per_channel (float): Probability to apply the cutoff to each channel.
        p_retain_std (float): Probability of restoring the original standard deviation after clipping.
    """

    def __init__(self,
                 percentile_lower: RandomScalar = 0.2,
                 percentile_upper: RandomScalar = 99.8,
                 p_synchronize_channels: bool = False,
                 p_per_channel: float = 1.0,
                 p_retain_std: float = 1.0):
        super().__init__()
        self.percentile_lower = percentile_lower
        self.percentile_upper = percentile_upper
        self.p_synchronize_channels = p_synchronize_channels
        self.p_per_channel = p_per_channel
        self.p_retain_std = p_retain_std

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """Sample, per channel, the percentile pair (or None) and the retain-std flag.

        Both returned lists have one entry per channel.
        """
        num_channels = image.shape[0]
        selected = [np.random.rand() < self.p_per_channel for _ in range(num_channels)]

        if self.p_synchronize_channels:
            # One pair, sampled once (lower first, then upper), shared by all selected channels.
            shared = (float(sample_scalar(self.percentile_lower)),
                      float(sample_scalar(self.percentile_upper)))
            percentiles = [shared if use else None for use in selected]
        else:
            # A fresh pair per selected channel (lower first, then upper).
            percentiles = [
                (float(sample_scalar(self.percentile_lower)),
                 float(sample_scalar(self.percentile_upper))) if use else None
                for use in selected
            ]

        # Unselected channels never retain std and consume no random number.
        retain_std_flags = [
            bounds is not None and np.random.rand() < self.p_retain_std
            for bounds in percentiles
        ]

        return {'percentiles': percentiles, 'retain_std': retain_std_flags}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Clip the selected channels, write them back into ``img`` and return it."""
        retain_std = params['retain_std']

        for ch, bounds in enumerate(params['percentiles']):
            if bounds is None:
                continue

            channel = img[ch]
            keep_std = retain_std[ch]
            if keep_std:
                std_before = channel.std()

            # Quantiles are computed in torch (float32) to avoid a numpy round trip.
            fractions = torch.tensor([bounds[0] / 100.0, bounds[1] / 100.0],
                                     device=channel.device, dtype=torch.float32)
            low_cut, high_cut = torch.quantile(channel.float(), fractions)

            clipped = channel.clamp(min=low_cut.item(), max=high_cut.item())

            if keep_std:
                std_after = clipped.std()
                # Skip rescaling for (near-)constant channels to avoid dividing by ~0.
                if std_after > 1e-8:
                    center = clipped.mean()
                    clipped = (clipped - center) / std_after * std_before + center

            img[ch] = clipped

        return img
84
+
85
if __name__ == '__main__':
    # Visual check: original, clipped result and their difference side by side.
    from batchviewer import view_batch

    image = torch.randn(1, 32, 64, 64) * 5

    settings = dict(
        percentile_lower=(0.5, 5),
        percentile_upper=(95, 99.5),
        p_synchronize_channels=True,
        p_per_channel=1.0,
        p_retain_std=0.5,
    )
    transform = CutOffOutliersTransform(**settings)

    params = transform.get_parameters(image=image)
    image_clipped = transform._apply_to_image(image.clone(), **params)

    view_batch(image, image_clipped, image_clipped-image)
File without changes
@@ -0,0 +1,177 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import List
4
+
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
7
+ from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
8
+
9
+
10
class BrightnessGradientAdditiveTransform(ImageOnlyTransform, LocalTransform):
    """
    Applies a localized brightness modulation to an image using a smooth Gaussian gradient.

    This transform creates a spatial Gaussian kernel (in 2D or 3D), optionally zero-centers it,
    scales its peak intensity, and adds it to the image. This can simulate intensity drift,
    local contrast changes, or smooth lighting artifacts.

    The effect is applied per channel, and each channel can have a different gradient or share the same one.

    ---
    Example use cases:
        - Simulating local contrast shifts in MRI
        - Adding spatial brightness gradients for robustness
        - Mimicking smooth scanner inhomogeneity fields

    Args:
        scale (RandomScalar):
            Controls the spatial spread of the Gaussian kernel (standard deviation).
            Can be:
                - float: fixed spread
                - (min, max): uniformly sampled per-dimension
                - callable(image_shape, dim): custom sampling per axis

        loc (RandomScalar):
            Controls the relative location of the Gaussian kernel (in percentage of image size).
            Can be:
                - (min, max): e.g. (-1, 2) allows centers to be far outside the image for smoother edges
                - callable(image_shape, dim): custom sampling per axis

        max_strength (RandomScalar):
            Peak value of the additive brightness change (positive or negative depending on the Gaussian).
            Can be:
                - float: fixed strength
                - (min, max): sampled strength
                - callable(image, kernel): fully custom

        same_for_all_channels (bool):
            If True, one shared kernel is used across all channels.
            If False, each channel gets its own random kernel and strength.

        mean_centered (bool):
            If True, the Gaussian kernel is mean-centered (i.e., ∑kernel = 0),
            which ensures the overall mean intensity of the image stays constant.

        clip_intensities (bool):
            If True, clamps the image after modification to the min/max the whole
            image tensor had before any kernel was added.
            Useful to prevent range overflow.

        p_per_channel (float):
            Probability to apply the transform to each channel independently.

    Returns:
        Modified image of the same shape with localized brightness modulation applied.

    Example:
        transform = BrightnessGradientAdditiveTransform(
            scale=(5, 15),
            max_strength=(0.1, 0.5),
            same_for_all_channels=True,
            mean_centered=True
        )
    """
    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 max_strength: RandomScalar = 1.0,
                 same_for_all_channels: bool = True,
                 mean_centered: bool = True,
                 clip_intensities: bool = False,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.max_strength = max_strength
        self.same_for_all_channels = same_for_all_channels
        self.mean_centered = mean_centered
        self.clip_intensities = clip_intensities
        self.p_per_channel = p_per_channel

    def _sample_scaled_kernel(self, image: torch.Tensor, spatial):
        """Build one kernel scaled so its largest absolute value equals the sampled strength.

        Returns None when the kernel is numerically flat (max |value| < 1e-8)
        or the sampled strength is exactly zero, i.e. when adding it would be
        a no-op.
        """
        # presumably a numpy array of shape `spatial` (it is later passed to
        # np.abs and torch.from_numpy) - TODO confirm against LocalTransform
        kernel = self._generate_kernel(spatial)
        if self.mean_centered:
            kernel -= kernel.mean()

        max_abs = np.abs(kernel).max()
        if max_abs < 1e-8:
            return None

        strength = sample_scalar(self.max_strength, image, kernel)
        if strength == 0.0:
            return None

        kernel /= max_abs
        kernel *= strength
        return kernel

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """Return ``{'kernels': [...]}`` with one kernel (or None) per channel."""
        C, *spatial = image.shape
        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(C)]

        # Early exit if nothing will be applied
        if not any(apply_channel):
            return {'kernels': [None] * C}

        if self.same_for_all_channels:
            kernel = self._sample_scaled_kernel(image, spatial)
            if kernel is None:
                return {'kernels': [None] * C}
            # The very same array object is shared by all selected channels.
            kernels = [kernel if apply else None for apply in apply_channel]
        else:
            kernels = [
                self._sample_scaled_kernel(image, spatial) if apply else None
                for apply in apply_channel
            ]

        return {'kernels': kernels}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Add the per-channel kernels to ``img`` in place and return it."""
        kernels = params['kernels']
        if all(kernel is None for kernel in kernels):
            return img

        if self.clip_intensities:
            # Bounds must be taken BEFORE the kernels are added; clamping to the
            # min/max of the already modified image would never change anything.
            lower = img.min()
            upper = img.max()

        # A kernel shared across channels is the same array object, so it is
        # converted to a tensor only once (keyed by object identity; the
        # `kernels` list keeps the arrays alive for the duration of the loop).
        converted = {}
        for c, kernel in enumerate(kernels):
            if kernel is None:
                continue
            key = id(kernel)
            if key not in converted:
                converted[key] = torch.from_numpy(kernel).to(img.device, dtype=img.dtype)
            img[c].add_(converted[key])

        if self.clip_intensities:
            img.clamp_(min=lower, max=upper)

        return img
153
+
154
if __name__ == '__main__':
    import torch
    from batchviewer import view_batch

    # Synthetic, z-score-like single-channel 3D volume laid out as (C, D, H, W)
    image = torch.randn(1, 32, 64, 64)

    settings = dict(
        scale=(25, 50),  # width of the Gaussian
        loc=(-0.5, 1.5),
        max_strength=(2, 5),  # strength of the modulation
        same_for_all_channels=True,
        mean_centered=True,
        clip_intensities=False,
        p_per_channel=1.0,  # always apply
    )
    transform = BrightnessGradientAdditiveTransform(**settings)

    # Sample parameters, then apply them to a copy so the input stays untouched
    params = transform.get_parameters(image=image)
    image_modulated = transform._apply_to_image(image.clone(), **params)

    # Show original and modulated volume side by side
    view_batch(image, image_modulated)