batchgeneratorsv2 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2/helpers/fft_conv.py +149 -0
- batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
- batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
- batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
- batchgeneratorsv2/transforms/local/__init__.py +0 -0
- batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
- batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
- batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
- batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
- batchgeneratorsv2/transforms/local/local_transform.py +86 -0
- batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
- batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
- batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
- batchgeneratorsv2/transforms/noise/rician.py +61 -0
- batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
- batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
- batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
- batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
- batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
- batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
- batchgeneratorsv2/transforms/utils/random.py +46 -0
- batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
- batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
- batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
- batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
- batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
- batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BGContrast():
    """Callable sampler for contrast multipliers.

    ``contrast_range`` is either a zero-argument callable that returns the
    factor directly, or an indexable ``(low, high)`` pair. For a pair, a coin is
    flipped first. On heads (probability 0.5), and only when ``low < 1``, the
    factor is drawn uniformly from ``[low, 1)``. Otherwise it is drawn from
    ``[max(low, 1), high)``. Contrast reduction and contrast increase are
    therefore equally likely whenever the range straddles 1, regardless of how
    asymmetric it is.
    """

    def __init__(self, contrast_range):
        self.contrast_range = contrast_range

    def sample_contrast(self, *args, **kwargs):
        """Return one contrast factor. Any positional/keyword arguments are ignored."""
        spec = self.contrast_range
        if callable(spec):
            return spec()
        # The coin is always drawn before anything else, which keeps the
        # numpy RNG stream consumption fixed.
        coin = np.random.random()
        if coin < 0.5 and spec[0] < 1:
            return np.random.uniform(spec[0], 1)
        return np.random.uniform(max(spec[0], 1), spec[1])

    def __call__(self, *args, **kwargs):
        return self.sample_contrast(*args, **kwargs)

    def __repr__(self):
        return f"{type(self).__name__}(contrast_range={self.contrast_range})"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
import torch
|
|
30
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
31
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ContrastTransform(ImageOnlyTransform):
    """Scale each selected channel's deviation from its mean by a random factor.

    Every selected channel is modified in place as
    ``x <- (x - mean(x)) * m + mean(x)``. The channel mean stays the same while
    its contrast is multiplied by ``m``.

    Args:
        contrast_range: Multiplier specification, resolved with ``sample_scalar``
            (called with ``image=`` and ``channel=`` keywords). The benchmark in
            this module passes ``BGContrast(...).sample_contrast``.
        preserve_range: If True, clamp each modified channel back to the min/max
            it had before the transform.
        synchronize_channels: If True, one multiplier is shared by all selected
            channels; otherwise one is sampled per selected channel.
        p_per_channel: Probability that a given channel is modified.
    """

    def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1.0):
        super().__init__()
        self.contrast_range = contrast_range
        self.preserve_range = preserve_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = float(p_per_channel)

    def get_parameters(self, **data_dict) -> dict:
        """Select channels and sample one multiplier per selected channel.

        Returns ``apply_to_channel`` (1-D index tensor on the image's device) and
        ``multipliers``. ``multipliers`` is a tensor aligned with the indices, or
        None when no channel was selected.
        """
        img = data_dict["image"]
        c = img.shape[0]

        # sample on correct device
        apply_idx = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()
        n = apply_idx.numel()

        if n == 0:
            multipliers = None
        elif self.synchronize_channels:
            m = float(sample_scalar(self.contrast_range, image=img, channel=None))
            multipliers = torch.full((n,), m, device=img.device, dtype=img.dtype)
        else:
            # sample_scalar yields one value per call, so this stays a Python loop.
            # .tolist() avoids iterating over (possibly device-resident) tensor scalars.
            # The float() cast matches the synchronized branch.
            ms = [float(sample_scalar(self.contrast_range, image=img, channel=ch)) for ch in apply_idx.tolist()]
            multipliers = torch.as_tensor(ms, device=img.device, dtype=img.dtype)

        return {"apply_to_channel": apply_idx, "multipliers": multipliers}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Apply the sampled multipliers in place and return ``img``."""
        idx = params["apply_to_channel"]
        multipliers = params["multipliers"]
        if multipliers is None or idx.numel() == 0:
            return img

        # Convert the indices to Python ints once. Calling int(idx[i]) inside the
        # loop forces a host/device synchronisation per channel on GPU.
        for c, m in zip(idx.tolist(), multipliers):
            x = img[c]  # basic indexing -> view; the in-place ops write into img
            mean = x.mean()
            # Original range is only needed when it must be restored afterwards.
            bounds = (x.min(), x.max()) if self.preserve_range else None

            x.sub_(mean).mul_(m).add_(mean)

            if bounds is not None:
                x.clamp_(*bounds)

        return img
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
if __name__ == '__main__':
    # Micro-benchmark: this ContrastTransform vs. batchgenerators' ContrastAugmentationTransform.
    import os
    from time import time

    # Restrict both implementations to one thread for a like-for-like comparison.
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _mean_runtime(transform, make_batch, repeats=100):
        """Average wall-clock seconds of ``transform(**make_batch())``; input creation is not timed."""
        durations = []
        for _ in range(repeats):
            batch = make_batch()
            start = time()
            transform(**batch)
            durations.append(time() - start)
        return np.mean(durations)

    # preserve_range=True, synchronize_channels=False, every channel modified
    torch_tr = ContrastTransform(BGContrast((0.75, 1.25)).sample_contrast, True, False, p_per_channel=1)
    print('torch', _mean_runtime(torch_tr, lambda: {'image': torch.ones((2, 128, 192, 64))}))

    from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform

    bg_tr = ContrastAugmentationTransform((0.75, 1.25), preserve_range=True, per_channel=True, p_per_channel=1)
    print('bg', _mean_runtime(bg_tr, lambda: {'data': np.ones((1, 2, 128, 192, 64))}))
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GammaTransform(ImageOnlyTransform):
    """Random gamma correction applied per channel, in place.

    Each selected channel is min-max normalised to [0, 1], raised to the power
    ``gamma`` and mapped back to its original range. Optionally the channel is
    negated before and after, so the curve acts on inverted intensities.
    Optionally its pre-gamma mean and standard deviation are restored afterwards.

    Args:
        gamma: Exponent specification, resolved with ``sample_scalar``.
        p_invert_image: Per-channel probability of applying gamma to the
            negated channel.
        synchronize_channels: If True, one gamma value is shared by all
            selected channels; otherwise one is sampled per selected channel.
        p_per_channel: Probability that a given channel is modified.
        p_retain_stats: Per-channel probability of restoring the channel's
            pre-gamma mean and standard deviation.
    """

    def __init__(self,
                 gamma: RandomScalar,
                 p_invert_image: float,
                 synchronize_channels: bool,
                 p_per_channel: float,
                 p_retain_stats: float):
        super().__init__()
        self.gamma = gamma
        self.p_invert_image = float(p_invert_image)
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = float(p_per_channel)
        self.p_retain_stats = float(p_retain_stats)

    def get_parameters(self, **data_dict) -> dict:
        """Sample which channels to modify and their per-channel settings.

        Returns ``apply_to_channel`` (1-D index tensor) together with
        ``retain_stats`` / ``invert_image`` (bool tensors) and ``gamma`` (values),
        all aligned with the indices. The last three are None when no channel is
        selected.
        """
        img: torch.Tensor = data_dict["image"]
        c = img.shape[0]
        device = img.device
        dtype = img.dtype

        apply_idx = (torch.rand(c, device=device) < self.p_per_channel).nonzero(as_tuple=False).flatten()
        n = apply_idx.numel()
        if n == 0:
            return {"apply_to_channel": apply_idx,
                    "retain_stats": None,
                    "invert_image": None,
                    "gamma": None}

        retain_stats = (torch.rand(n, device=device) < self.p_retain_stats)
        invert_image = (torch.rand(n, device=device) < self.p_invert_image)

        if self.synchronize_channels:
            g = float(sample_scalar(self.gamma, image=img, channel=None))
            gamma = torch.full((n,), g, device=device, dtype=dtype)
        else:
            # sample_scalar is scalar-based; keep loop but avoid tensor scalar iteration
            gs = [float(sample_scalar(self.gamma, image=img, channel=int(ch))) for ch in apply_idx.tolist()]
            gamma = torch.as_tensor(gs, device=device, dtype=dtype)

        return {
            "apply_to_channel": apply_idx,
            "retain_stats": retain_stats,
            "invert_image": invert_image,
            "gamma": gamma,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Apply the sampled gamma settings in place and return ``img``."""
        idx: torch.Tensor = params["apply_to_channel"]
        if idx.numel() == 0:
            return img

        gamma: torch.Tensor = params["gamma"]
        # Transfer indices and flags to Python once. Calling int()/bool() on
        # single elements inside the loop forces one host/device sync per element
        # when these tensors live on a GPU.
        channels = idx.tolist()
        retain_flags = params["retain_stats"].tolist()
        invert_flags = params["invert_image"].tolist()

        # guards the divisions by a zero intensity range / zero std
        eps = 1e-7

        for k, (c, retain, invert) in enumerate(zip(channels, retain_flags, invert_flags)):
            g = gamma[k]
            x = img[c]  # basic indexing -> view; in-place ops modify img directly

            if invert:
                x.mul_(-1)

            if retain:
                mean = x.mean()
                std = x.std()

            minm = x.min()
            maxm = x.max()
            rnge = maxm - minm
            denom = torch.clamp(rnge, min=eps)

            # In place: x = (((x - min) / max(range, eps)) ** g) * range + min
            x.sub_(minm).div_(denom).pow_(g).mul_(rnge).add_(minm)

            if retain:
                mn_here = x.mean()
                std_here = x.std()
                x.sub_(mn_here).mul_(std / torch.clamp(std_here, min=eps)).add_(mean)

            if invert:
                x.mul_(-1)

        return img
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
if __name__ == '__main__':
    # Micro-benchmark: this GammaTransform vs. batchgenerators' GammaTransform.
    import os
    from time import time

    import numpy as np

    # Restrict both implementations to one thread for a like-for-like comparison.
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _mean_runtime(transform, make_batch, repeats=100):
        """Average wall-clock seconds of ``transform(**make_batch())``; input creation is not timed."""
        durations = []
        for _ in range(repeats):
            batch = make_batch()
            start = time()
            transform(**batch)
            durations.append(time() - start)
        return np.mean(durations)

    # p_invert_image=0, per-channel gammas, every channel, always retain stats
    torch_tr = GammaTransform((0.7, 1.5), 0, False, 1, 1)
    print('torch', _mean_runtime(torch_tr, lambda: {'image': torch.ones((2, 128, 192, 64))}))

    from batchgenerators.transforms.color_transforms import GammaTransform as BGGamma

    bg_tr = BGGamma((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1)
    print('bg', _mean_runtime(bg_tr, lambda: {'data': np.ones((1, 2, 128, 192, 64))}))
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GaussianNoiseTransform(ImageOnlyTransform):
    """Add zero-mean Gaussian noise to randomly selected channels.

    Despite the parameter name, the value drawn from ``noise_variance`` is used
    as the *standard deviation* of the noise. It is passed as ``std`` to
    ``normal_`` or multiplied onto unit-variance noise.

    Args:
        noise_variance: Noise standard deviation specification, resolved with
            ``sample_scalar(noise_variance, image)``.
        p_per_channel: Probability that a given channel receives noise.
        synchronize_channels: If True, all selected channels share one noise
            level; otherwise one level is drawn per selected channel.
    """

    def __init__(self,
                 noise_variance: RandomScalar = (0, 0.1),
                 p_per_channel: float = 1.,
                 synchronize_channels: bool = False):
        super().__init__()
        self.noise_variance = noise_variance
        self.p_per_channel = p_per_channel
        self.synchronize_channels = synchronize_channels

    def get_parameters(self, **data_dict) -> dict:
        """Choose which channels get noise and draw their noise levels.

        Returns:
            ``apply_mask``: boolean tensor over channels, on the image's device.
            ``apply_idx``: indices of the True entries.
            ``num_apply``: number of selected channels.
            ``sigmas``: None if nothing is selected; a single value when channels
                are synchronized; otherwise a list with one value per selected
                channel.
        """
        image = data_dict["image"]
        mask = torch.rand(image.shape[0], device=image.device) < self.p_per_channel
        chosen = mask.nonzero(as_tuple=False).flatten()
        count = chosen.numel()

        if count == 0:
            noise_levels = None
        elif self.synchronize_channels:
            noise_levels = sample_scalar(self.noise_variance, image)
        else:
            noise_levels = [sample_scalar(self.noise_variance, image) for _ in range(count)]

        return {"apply_mask": mask, "apply_idx": chosen, "num_apply": count, "sigmas": noise_levels}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Add the noise to the selected channels of ``img`` and return it."""
        count = params["num_apply"]
        if count == 0:
            return img

        chosen = params["apply_idx"]
        noise_levels = params["sigmas"]
        if noise_levels is None:
            return img

        noise_shape = (count, *img.shape[1:])

        if self.synchronize_channels:
            noise = torch.empty(noise_shape, device=img.device, dtype=img.dtype).normal_(mean=0.0, std=float(noise_levels))
        else:
            # One std per selected channel, shaped (count, 1, 1, ...) to broadcast
            # over the spatial dimensions.
            per_channel_std = torch.as_tensor(noise_levels, device=img.device, dtype=img.dtype)
            per_channel_std = per_channel_std.view(count, *([1] * (len(noise_shape) - 1)))
            noise = torch.empty(noise_shape, device=img.device, dtype=img.dtype).normal_()
            noise.mul_(per_channel_std)

        # img[chosen] with an index tensor is a copy (advanced indexing); the
        # augmented assignment writes the sum back via __setitem__.
        img[chosen] += noise
        return img
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
if __name__ == "__main__":
    # Micro-benchmark: this GaussianNoiseTransform vs. the batchgenerators one.
    # `os` is imported here because the module's top-level imports do not
    # include it; without this the script failed with a NameError.
    import os
    from time import time

    import numpy as np

    # Restrict both implementations to one thread for a like-for-like comparison.
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    gnt = GaussianNoiseTransform((0, 0.1), 1, False)

    times = []
    for _ in range(1000):
        data_dict = {'image': torch.ones((2, 32, 32, 32))}
        st = time()
        out = gnt(**data_dict)
        times.append(time() - st)
    print('torch', np.mean(times))

    # Alias the import so the module-level GaussianNoiseTransform defined
    # above is not rebound to the batchgenerators class.
    from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform as BGGaussianNoiseTransform

    gnt_bg = BGGaussianNoiseTransform((0, 0.1), 1, 1, True)

    times = []
    for _ in range(1000):
        data_dict = {'data': np.ones((1, 2, 32, 32, 32))}
        st = time()
        out = gnt_bg(**data_dict)
        times.append(time() - st)

    print('bg', np.mean(times))
    # torch is 2.5x faster (original author's measurement)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class InvertImageTransform(ImageOnlyTransform):
    """Mirror selected channels' intensities around their own mean.

    With probability ``p_invert_image`` the transform is applied. In that case,
    with probability ``p_synchronize_channels`` every channel is inverted.
    Otherwise each channel is inverted independently with probability
    ``p_per_channel``. Inversion maps ``x`` to ``mean(x) - (x - mean(x))``,
    in place.
    """

    def __init__(self, p_invert_image: float, p_synchronize_channels: float = 1, p_per_channel: float = 1):
        super().__init__()
        self.p_invert_image = p_invert_image
        self.p_synchronize_channels = p_synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        """Decide whether to invert and which channels.

        Returns ``apply`` (bool) and ``apply_to_channel``. The latter is an empty
        list when not applying, otherwise a 1-D tensor of channel indices.
        """
        shape = data_dict['image'].shape
        do_invert = np.random.uniform() < self.p_invert_image
        if not do_invert:
            channels = []
        elif np.random.uniform() < self.p_synchronize_channels:
            channels = torch.arange(0, shape[0])
        else:
            channels = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
        return {
            'apply_to_channel': channels,
            'apply': do_invert,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Invert the selected channels of ``img`` in place and return it."""
        if params['apply']:
            for ch in params['apply_to_channel']:
                center = img[ch].mean()
                # Indexed augmented assignments write back through __setitem__.
                img[ch] -= center
                img[ch] *= -1
                img[ch] += center
        return img
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
if __name__ == '__main__':
    # Visual / numeric sanity check for InvertImageTransform.
    mbt = InvertImageTransform(0.5, 0.5, 0.5)
    from batchviewer import view_batch

    for _ in range(100):
        # Channel 0: first 10 of 20 slices are -1 (mean 0).
        # Channel 1: first 5 of 20 slices are -1 (mean 0.5). Everything else is +1.
        data_dict = {'image': torch.ones((2, 20, 192, 64))}
        data_dict['image'][0, :10] = -1
        data_dict['image'][1, :5] = -1
        ret = mbt(**data_dict)
        # Corner values stay -1 when untouched. After inversion around the
        # channel mean they become 1 (channel 0) or 2 (channel 1).
        print(ret['image'][0, 0, 0, 0], ret['image'][1, 0, 0, 0])

    # Show the result that was actually printed. Calling mbt(**data_dict)
    # again would draw fresh random parameters. Because _apply_to_image
    # modifies its input in place, it may also stack a second inversion on
    # already-inverted data.
    view_batch(ret['image'])
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CutOffOutliersTransform(ImageOnlyTransform):
    """
    Clamps intensities in the image to percentiles to remove outliers,
    and optionally rescales the result to retain the original standard deviation.

    Args:
        percentile_lower (RandomScalar): Lower cutoff percentile (0-100).
        percentile_upper (RandomScalar): Upper cutoff percentile (0-100).
        p_synchronize_channels (bool): If True, the same percentiles are used for all
            channels. Despite the ``p_`` prefix this is a flag, not a probability.
        p_per_channel (float): Probability to apply the cutoff to each channel.
        p_retain_std (float): Probability of retaining the original standard deviation after clipping.
    """

    # torch.quantile rejects inputs above roughly 16M (2**24) elements. Larger
    # channels, e.g. big 3D volumes, go through an equivalent sort-based path.
    _TORCH_QUANTILE_MAX_NUMEL = 2 ** 24

    def __init__(self,
                 percentile_lower: RandomScalar = 0.2,
                 percentile_upper: RandomScalar = 99.8,
                 p_synchronize_channels: bool = False,
                 p_per_channel: float = 1.0,
                 p_retain_std: float = 1.0):
        super().__init__()
        self.percentile_lower = percentile_lower
        self.percentile_upper = percentile_upper
        self.p_synchronize_channels = p_synchronize_channels
        self.p_per_channel = p_per_channel
        self.p_retain_std = p_retain_std

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """Sample per-channel percentile pairs and retain-std flags.

        Returns ``percentiles`` (list with one ``(lower, upper)`` tuple or None per
        channel) and ``retain_std`` (list of bools, always False for skipped
        channels).
        """
        C = image.shape[0]
        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(C)]

        if self.p_synchronize_channels:
            lower = float(sample_scalar(self.percentile_lower))
            upper = float(sample_scalar(self.percentile_upper))
            percentiles = [(lower, upper) if apply else None for apply in apply_channel]
        else:
            percentiles = []
            for apply in apply_channel:
                if not apply:
                    percentiles.append(None)
                else:
                    lower = float(sample_scalar(self.percentile_lower))
                    upper = float(sample_scalar(self.percentile_upper))
                    percentiles.append((lower, upper))

        retain_std_flags = [
            np.random.rand() < self.p_retain_std if p is not None else False
            for p in percentiles
        ]

        return {'percentiles': percentiles, 'retain_std': retain_std_flags}

    @classmethod
    def _quantiles(cls, values: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Linear-interpolation quantiles of a float tensor, for any input size.

        Small inputs go to torch.quantile unchanged. Larger ones are flattened,
        sorted once and interpolated between neighbouring order statistics, using
        the same 'linear' rule as torch.quantile: rank = q * (n - 1), lerp
        between floor and ceil. Unlike torch.quantile, this path does not
        propagate NaNs.
        """
        if values.numel() < cls._TORCH_QUANTILE_MAX_NUMEL:
            return torch.quantile(values, q)
        sorted_vals = torch.sort(values.reshape(-1)).values
        ranks = q * (sorted_vals.numel() - 1)
        below = ranks.floor().long()
        above = ranks.ceil().long()
        weights = ranks - below.to(ranks.dtype)
        return torch.lerp(sorted_vals[below], sorted_vals[above], weights)

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Clip each selected channel to its sampled percentiles, in place."""
        percentiles = params['percentiles']
        retain_std = params['retain_std']

        for c, perc in enumerate(percentiles):
            if perc is None:
                continue

            img_c = img[c]
            if retain_std[c]:
                orig_std = img_c.std()

            # Percentiles in torch to avoid a numpy round trip
            q = torch.tensor([perc[0] / 100.0, perc[1] / 100.0], device=img_c.device, dtype=torch.float32)
            lower_val, upper_val = self._quantiles(img_c.float(), q)

            img_c_clipped = img_c.clamp(min=lower_val.item(), max=upper_val.item())

            if retain_std[c]:
                clipped_std = img_c_clipped.std()
                if clipped_std > 1e-8:
                    # Rescale around the clipped mean so the spread matches the original.
                    clipped_mean = img_c_clipped.mean()
                    img_c_clipped = (img_c_clipped - clipped_mean) / clipped_std * orig_std + clipped_mean

            img[c] = img_c_clipped

        return img
|
|
84
|
+
|
|
85
|
+
if __name__ == '__main__':
    from batchviewer import view_batch

    # Synthetic single-channel volume with a wide intensity spread.
    volume = torch.randn(1, 32, 64, 64) * 5

    clipper = CutOffOutliersTransform(
        percentile_lower=(0.5, 5),
        percentile_upper=(95, 99.5),
        p_synchronize_channels=True,
        p_per_channel=1.0,
        p_retain_std=0.5,
    )

    # Run both stages directly. Work on a clone so the input stays intact for comparison.
    sampled = clipper.get_parameters(image=volume)
    clipped = clipper._apply_to_image(volume.clone(), **sampled)

    view_batch(volume, clipped, clipped - volume)
|
|
File without changes
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
6
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
7
|
+
from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BrightnessGradientAdditiveTransform(ImageOnlyTransform, LocalTransform):
    """
    Applies a localized brightness modulation to an image using a smooth Gaussian gradient.

    This transform creates a spatial Gaussian kernel (in 2D or 3D), optionally zero-centers it,
    scales its peak intensity, and adds it to the image. This can simulate intensity drift,
    local contrast changes, or smooth lighting artifacts.

    The effect is applied per channel, and each channel can have a different gradient or share the same one.

    ---
    Example use cases:
    - Simulating local contrast shifts in MRI
    - Adding spatial brightness gradients for robustness
    - Mimicking smooth scanner inhomogeneity fields

    Args:
        scale (RandomScalar):
            Controls the spatial spread of the Gaussian kernel (standard deviation).
            Interpreted by LocalTransform. Can be:
            - float: fixed spread
            - (min, max): uniformly sampled per-dimension
            - callable(image_shape, dim): custom sampling per axis

        loc (RandomScalar):
            Controls the relative location of the Gaussian kernel (in percentage of image size).
            Interpreted by LocalTransform. Can be:
            - (min, max): e.g. (-1, 2) allows centers to be far outside the image for smoother edges
            - callable(image_shape, dim): custom sampling per axis

        max_strength (RandomScalar):
            Peak absolute value of the additive brightness change. The kernel is
            divided by its max |value| and multiplied by this strength. Can be:
            - float: fixed strength
            - (min, max): sampled strength
            - callable(image, kernel): fully custom

        same_for_all_channels (bool):
            If True, one shared kernel is used across all channels.
            If False, each channel gets its own random kernel and strength.

        mean_centered (bool):
            If True, the Gaussian kernel is mean-centered (i.e., ∑kernel = 0),
            so the mean intensity of each modified channel stays constant
            (unless clip_intensities then changes values).

        clip_intensities (bool):
            If True, each modified channel is clamped back to the min/max it had
            before the gradient was added. Useful to prevent range overflow.

        p_per_channel (float):
            Probability to apply the transform to each channel independently.

    Returns:
        Modified image of the same shape with localized brightness modulation applied.

    Example:
        transform = BrightnessGradientAdditiveTransform(
            scale=(5, 15),
            max_strength=(0.1, 0.5),
            same_for_all_channels=True,
            mean_centered=True
        )
    """
    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 max_strength: RandomScalar = 1.0,
                 same_for_all_channels: bool = True,
                 mean_centered: bool = True,
                 clip_intensities: bool = False,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.max_strength = max_strength
        self.same_for_all_channels = same_for_all_channels
        self.mean_centered = mean_centered
        self.clip_intensities = clip_intensities
        self.p_per_channel = p_per_channel

    def _sample_brightness_kernel(self, image: torch.Tensor, spatial):
        """Draw one kernel and scale its peak |value| to a sampled strength.

        The kernel is mean-centered first when ``mean_centered`` is set.
        Returns None when the kernel is numerically all zero or the sampled
        strength is exactly 0, i.e. when adding it would not change the image.
        The kernel is presumably a numpy array (it is later passed to
        ``torch.from_numpy``); confirm against ``LocalTransform._generate_kernel``.
        """
        kernel = self._generate_kernel(spatial)
        if self.mean_centered:
            kernel -= kernel.mean()

        max_abs = np.abs(kernel).max()
        if max_abs < 1e-8:
            return None

        # Strength is sampled only after the degenerate-kernel check, so the
        # RNG is not consumed for kernels that get discarded.
        strength = sample_scalar(self.max_strength, image, kernel)
        if strength == 0.0:
            return None

        kernel /= max_abs
        kernel *= strength
        return kernel

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """Return ``{'kernels': [...]}`` with one scaled kernel or None per channel."""
        C, *spatial = image.shape
        apply_channel = [np.random.rand() < self.p_per_channel for _ in range(C)]

        # Early exit if nothing will be applied
        if not any(apply_channel):
            return {'kernels': [None] * C}

        if self.same_for_all_channels:
            kernel = self._sample_brightness_kernel(image, spatial)
            if kernel is None:
                return {'kernels': [None] * C}
            kernels = [kernel if apply else None for apply in apply_channel]
        else:
            # Evaluated in channel order, so the RNG stream is consumed channel by channel.
            kernels = [self._sample_brightness_kernel(image, spatial) if apply else None
                       for apply in apply_channel]

        return {'kernels': kernels}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Add each channel's kernel in place and optionally restore its range."""
        for c, kernel in enumerate(params['kernels']):
            if kernel is None:
                continue
            channel = img[c]  # basic indexing -> view; in-place ops modify img
            if self.clip_intensities:
                # Record the range *before* adding the gradient. Clamping to the
                # post-modification min/max would be a no-op.
                orig_min, orig_max = channel.min(), channel.max()
            channel.add_(torch.from_numpy(kernel).to(img.device, dtype=img.dtype))
            if self.clip_intensities:
                channel.clamp_(min=orig_min, max=orig_max)

        return img
|
|
153
|
+
|
|
154
|
+
if __name__ == '__main__':
    from batchviewer import view_batch

    # Random single-channel 3D volume (C, D, H, W), roughly zero mean / unit variance.
    volume = torch.randn(1, 32, 64, 64)

    settings = dict(
        scale=(25, 50),          # controls width of Gaussian
        loc=(-0.5, 1.5),
        max_strength=(2, 5),     # how strong the modulation is
        same_for_all_channels=True,
        mean_centered=True,
        clip_intensities=False,
        p_per_channel=1.0,       # always apply
    )
    gradient = BrightnessGradientAdditiveTransform(**settings)

    # Sample parameters and apply them to a copy so both images can be compared.
    sampled = gradient.get_parameters(image=volume)
    modulated = gradient._apply_to_image(volume.clone(), **sampled)

    view_batch(volume, modulated)
|