batchgeneratorsv2 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2/helpers/fft_conv.py +149 -0
- batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
- batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
- batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
- batchgeneratorsv2/transforms/local/__init__.py +0 -0
- batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
- batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
- batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
- batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
- batchgeneratorsv2/transforms/local/local_transform.py +86 -0
- batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
- batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
- batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
- batchgeneratorsv2/transforms/noise/rician.py +61 -0
- batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
- batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
- batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
- batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
- batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
- batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
- batchgeneratorsv2/transforms/utils/random.py +46 -0
- batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
- batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
- batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
- batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
- batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
- batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LocalContrastTransform(ImageOnlyTransform, LocalTransform):
    """
    Changes contrast inside a randomly placed, smoothly bounded region of the image.

    For every selected channel a contrast-scaled copy ``(x - m) * f + m`` is computed, where ``m`` is the
    kernel-weighted mean of the channel and ``f`` the sampled contrast factor. The copy is then blended with
    the original channel, using the weight map from ``LocalTransform._generate_kernel`` (values in [0, 1]).

    Args:
        scale (RandomScalar): Spread (standard deviation, in voxels) of the Gaussian weight profile per axis.
        loc (RandomScalar): Center of the Gaussian as a fraction of each axis length; values outside [0, 1]
            place the center outside the image.
        new_contrast (RandomScalar): Contrast factor ``f``. 1.0 leaves the image unchanged.
        same_for_all_channels (bool): If True, one weight map and one factor are shared by all selected
            channels; otherwise both are sampled separately per channel.
        p_per_channel (float): Probability that an individual channel is modified.
    """

    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 new_contrast: RandomScalar = (0.5, 1.5),
                 same_for_all_channels: bool = True,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.new_contrast = new_contrast
        self.same_for_all_channels = same_for_all_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        n_channels, *spatial_shape = image.shape
        selected = [np.random.rand() < self.p_per_channel for _ in range(n_channels)]

        if not any(selected):
            return {'kernels': [None] * n_channels, 'contrasts': [None] * n_channels}

        kernels = [None] * n_channels
        contrasts = [None] * n_channels
        shared = None
        for idx, use in enumerate(selected):
            if not use:
                continue
            # Sample once (shared mode) or afresh for every selected channel (per-channel mode).
            if shared is None or not self.same_for_all_channels:
                shared = (self._generate_kernel(spatial_shape).astype(np.float32),
                          sample_scalar(self.new_contrast))
            kernels[idx], contrasts[idx] = shared

        return {'kernels': kernels, 'contrasts': contrasts}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        arr = img.cpu().numpy()

        for idx, (weight_map, factor) in enumerate(zip(params['kernels'], params['contrasts'])):
            if weight_map is None:
                continue  # channel was not selected in get_parameters

            original = arr[idx]
            # Kernel-weighted mean; the epsilon guards against an all-zero weight map.
            local_mean = (original * weight_map).sum() / (weight_map.sum() + 1e-8)
            rescaled = (original - local_mean) * factor + local_mean
            arr[idx] = self.run_interpolation(original, rescaled, weight_map)

        return torch.from_numpy(arr).to(img.device, dtype=img.dtype)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
if __name__ == '__main__':
    from batchviewer import view_batch

    # Synthetic single-channel volume (C, D, H, W) with uniform random intensities.
    demo_volume = torch.rand(1, 32, 64, 64)

    local_contrast = LocalContrastTransform(scale=(10, 20), new_contrast=(0.3, 2.0), p_per_channel=1.0)

    # Sample parameters once, then apply them to a clone so demo_volume itself is never handed to the transform.
    sampled = local_contrast.get_parameters(image=demo_volume)
    augmented = local_contrast._apply_to_image(demo_volume.clone(), **sampled)

    # Show the original and the augmented volume side by side.
    view_batch(demo_volume, augmented)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LocalGammaTransform(ImageOnlyTransform, LocalTransform):
    """
    Applies gamma correction inside a randomly placed, smoothly bounded region of the image.

    Each selected channel is min-max normalized to [0, 1] and raised to the power ``gamma``. The result is
    blended with the normalized channel using the weight map from ``LocalTransform._generate_kernel``, and
    the blend is mapped back to the channel's original intensity range.

    Parameters:
        scale (RandomScalar): Spread (standard deviation, in voxels) of the Gaussian weight profile per axis.
            Large values (e.g. 10-30) give broad regions.
        loc (RandomScalar): Center of the Gaussian as a fraction of each axis length; e.g. (-1, 2) also allows
            centers outside the image.
        gamma (RandomScalar): Exponent applied inside the region.
        same_for_all_channels (bool): If True, one weight map and one gamma are shared by all selected channels;
            otherwise both are sampled per channel.
        p_per_channel (float): Probability that an individual channel is modified.
    """

    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 gamma: RandomScalar = (0.5, 1),
                 same_for_all_channels: bool = True,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.gamma = gamma
        self.same_for_all_channels = same_for_all_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        n_channels, *spatial_shape = image.shape
        selected = [np.random.rand() < self.p_per_channel for _ in range(n_channels)]

        if not any(selected):
            return {'kernels': [None] * n_channels, 'gammas': [None] * n_channels}

        kernels = [None] * n_channels
        gammas = [None] * n_channels
        shared = None
        for idx, use in enumerate(selected):
            if not use:
                continue
            # Sample once (shared mode) or afresh for every selected channel (per-channel mode).
            if shared is None or not self.same_for_all_channels:
                shared = (self._generate_kernel(spatial_shape).astype(np.float32),
                          sample_scalar(self.gamma))
            kernels[idx], gammas[idx] = shared

        return {'kernels': kernels, 'gammas': gammas}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        arr = img.cpu().numpy()

        for idx, (weight_map, exponent) in enumerate(zip(params['kernels'], params['gammas'])):
            if weight_map is None:
                continue  # channel was not selected in get_parameters

            values = arr[idx]
            lo, hi = values.min(), values.max()
            value_range = max(hi - lo, 1e-8)  # avoids division by zero for constant channels

            unit = (values - lo) / value_range  # min-max normalized to [0, 1]
            corrected = np.power(unit, exponent)

            # Blend in the normalized domain, then map back to the original intensity range.
            arr[idx] = self.run_interpolation(unit, corrected, weight_map) * value_range + lo

        return torch.from_numpy(arr).to(img.device, dtype=img.dtype)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
if __name__ == '__main__':
    from batchviewer import view_batch

    demo_volume = torch.rand(1, 32, 64, 64)  # (C, D, H, W)

    def bimodal_gamma(*_):
        # gamma from [0.01, 1) or [1, 3) with equal probability
        if np.random.rand() < 0.5:
            return np.random.uniform(0.01, 1)
        return np.random.uniform(1, 3)

    local_gamma = LocalGammaTransform(
        scale=(10, 20),
        gamma=bimodal_gamma,
        same_for_all_channels=False,
        p_per_channel=1.0
    )

    # Sample parameters once and apply them to a clone of the demo volume.
    sampled = local_gamma.get_parameters(image=demo_volume)
    augmented = local_gamma._apply_to_image(demo_volume.clone(), **sampled)

    view_batch(demo_volume, augmented)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
|
|
4
|
+
from scipy.ndimage import gaussian_filter
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
6
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LocalSmoothingTransform(ImageOnlyTransform, LocalTransform):
    """
    Blurs a randomly placed, smoothly bounded region of the image.

    A Gaussian-filtered copy of each selected channel is blended with the original. The per-voxel blend weight
    is the weight map from ``LocalTransform._generate_kernel`` multiplied by a sampled strength, so the blurred
    copy contributes at most ``strength`` anywhere.

    Args:
        scale (RandomScalar): Spread (standard deviation, in voxels) of the Gaussian weight profile per axis.
        loc (RandomScalar): Center of the Gaussian as a fraction of each axis length.
        smoothing_strength (RandomScalar): Maximum weight of the blurred copy in the blend, expected in [0, 1].
        kernel_size (RandomScalar): Sigma passed to ``scipy.ndimage.gaussian_filter``. Sampled once per call
            and shared by all channels.
        same_for_all_channels (bool): If True, one weight map and one strength are shared by all selected
            channels; otherwise both are sampled per channel.
        p_per_channel (float): Probability that an individual channel is modified.
    """

    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 smoothing_strength: RandomScalar = (0.5, 1.0),
                 kernel_size: RandomScalar = (0.5, 1.5),
                 same_for_all_channels: bool = True,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.smoothing_strength = smoothing_strength
        self.kernel_size = kernel_size
        self.same_for_all_channels = same_for_all_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        n_channels, *spatial_shape = image.shape
        selected = [np.random.rand() < self.p_per_channel for _ in range(n_channels)]

        if not any(selected):
            return {'kernels': [None] * n_channels, 'sigma': None, 'strengths': [None] * n_channels}

        # A single blur sigma is used for every channel, independent of same_for_all_channels.
        blur_sigma = sample_scalar(self.kernel_size)

        kernels = [None] * n_channels
        strengths = [None] * n_channels
        shared = None
        for idx, use in enumerate(selected):
            if not use:
                continue
            # Sample once (shared mode) or afresh for every selected channel (per-channel mode).
            if shared is None or not self.same_for_all_channels:
                shared = (self._generate_kernel(spatial_shape).astype(np.float32),
                          sample_scalar(self.smoothing_strength))
            kernels[idx], strengths[idx] = shared

        return {'kernels': kernels, 'sigma': blur_sigma, 'strengths': strengths}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        arr = img.cpu().numpy()
        blur_sigma = params['sigma']

        for idx, (weight_map, strength) in enumerate(zip(params['kernels'], params['strengths'])):
            if weight_map is None:
                continue  # channel was not selected in get_parameters

            # Scaling the weight map caps the blurred copy's share of the blend at `strength`.
            scaled_weights = weight_map * strength
            blurred = gaussian_filter(arr[idx], sigma=blur_sigma)
            arr[idx] = self.run_interpolation(arr[idx], blurred, scaled_weights)

        return torch.from_numpy(arr).to(img.device, dtype=img.dtype)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
if __name__ == '__main__':
    from batchviewer import view_batch

    # Synthetic single-channel volume (C, D, H, W) with uniform random intensities.
    demo_volume = torch.rand(1, 32, 64, 64)

    # Local smoothing with the region center kept inside the image (loc in [0, 1]).
    local_smoothing = LocalSmoothingTransform(loc=(0, 1), scale=(10, 20), kernel_size=(3, 10), p_per_channel=1.0)

    # Sample parameters once and apply them to a clone of the demo volume.
    sampled = local_smoothing.get_parameters(image=demo_volume)
    augmented = local_smoothing._apply_to_image(demo_volume.clone(), **sampled)

    view_batch(demo_volume, augmented)
|
|
98
|
+
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import scipy.stats as st
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LocalTransform(ABC):
    """
    Mixin that produces randomly placed, smooth spatial weight maps ("kernels") for local augmentations.

    A kernel is the outer product of one 1-D Gaussian profile per spatial axis, min-max rescaled to [0, 1].
    Subclasses blend an original image with a modified version of it using such a kernel as per-voxel weight.

    Args:
        scale (RandomScalar): Standard deviation of the Gaussian profile, in voxels. Sampled per axis via
            ``sample_scalar(scale, img_shp, axis)``.
        loc (RandomScalar): Center of the Gaussian as a fraction of the axis length (the sampled value is
            multiplied by the axis size). Values outside [0, 1] place the center outside the image. Sampled
            per axis like ``scale``.
    """

    def __init__(self, scale: RandomScalar, loc: RandomScalar = (-1, 2)):
        self.loc = loc
        self.scale = scale

    @staticmethod
    def _outer_product(axis_profiles: Tuple[np.ndarray, ...]) -> np.ndarray:
        """
        Combines 1-D per-axis profiles into one N-D array: ``out[i, j, ...] = p0[i] * p1[j] * ...``.

        Args:
            axis_profiles: Non-empty sequence of 1-D arrays, one per spatial axis.

        Returns:
            Array of shape ``tuple(len(p) for p in axis_profiles)``.
        """
        kernel = np.asarray(axis_profiles[0])
        for profile in axis_profiles[1:]:
            kernel = np.multiply.outer(kernel, profile)
        return kernel

    def _generate_kernel(self, img_shp: Tuple[int, ...]) -> np.ndarray:
        """
        Builds one random weight map of shape ``img_shp`` with values in [0, 1].

        For each axis a center (``loc`` times the axis length) and a spread (``scale``) are sampled. The
        voxel-integrated Gaussian profile along the axis is the difference of the normal CDF evaluated at the
        voxel edges. The per-axis profiles are combined by outer product and min-max rescaled so the largest
        weight is 1; a constant map ends up all zeros.

        Args:
            img_shp: Spatial shape (no channel axis). Any number (>= 1) of dimensions is supported.

        Returns:
            float32 array of shape ``img_shp``.

        Raises:
            ValueError: If ``img_shp`` is empty.
        """
        ndim = len(img_shp)
        if ndim == 0:
            raise ValueError("img_shp must contain at least one spatial dimension")

        axis_profiles = []
        for d in range(ndim):
            # loc before scale per axis: keeps the order of random draws stable.
            loc_val = sample_scalar(self.loc, img_shp, d)
            scale_val = sample_scalar(self.scale, img_shp, d)
            # Voxel edges -0.5, 0.5, ..., s - 0.5; np.diff of the CDF gives the probability mass per voxel.
            edges = np.arange(-0.5, img_shp[d] + 0.5, dtype=np.float32)
            cdf = st.norm.cdf(edges, loc=loc_val * img_shp[d], scale=scale_val)
            axis_profiles.append(np.diff(cdf).astype(np.float32))

        kernel = self._outer_product(tuple(axis_profiles))

        kernel -= kernel.min()
        kernel_max = kernel.max()
        if kernel_max > 0:
            kernel /= kernel_max
        return kernel

    def _generate_multiple_kernel_image(self, img_shp: Tuple[int, ...], num_kernels: int) -> np.ndarray:
        """
        Places multiple additive Gaussians in the image and normalizes the sum to [0, 1].

        Parameters:
            img_shp (Tuple[int, ...]): Spatial shape (e.g., (X, Y[, Z]))
            num_kernels (int): Number of kernels to generate and sum

        Returns:
            np.ndarray: Combined float32 kernel image with values in [0, 1] (all zeros if the sum is constant)
        """
        kernel_image = np.zeros(img_shp, dtype=np.float32)
        for _ in range(num_kernels):
            kernel_image += self._generate_kernel(img_shp)

        kernel_image -= kernel_image.min()
        kernel_max = kernel_image.max()
        if kernel_max > 0:
            kernel_image /= kernel_max
        return kernel_image

    @staticmethod
    def invert_kernel(kernel_image: np.ndarray) -> np.ndarray:
        """
        Inverts a normalized kernel: 1 - kernel

        Assumes input is in [0, 1].

        Parameters:
            kernel_image (np.ndarray): Input kernel in [0, 1]

        Returns:
            np.ndarray: Inverted kernel in [0, 1]
        """
        return 1.0 - kernel_image

    @staticmethod
    def run_interpolation(original_image: np.ndarray,
                          modified_image: np.ndarray,
                          kernel_image: np.ndarray) -> np.ndarray:
        """
        Blends original and modified images using the given kernel as a per-voxel weight map.

        Parameters:
            original_image (np.ndarray): Unmodified input image
            modified_image (np.ndarray): Modified version (e.g., gamma-corrected)
            kernel_image (np.ndarray): Kernel in [0, 1], where 0 = keep original, 1 = keep modified

        Returns:
            np.ndarray: ``original * (1 - kernel) + modified * kernel``
        """
        return original_image * (1.0 - kernel_image) + modified_image * kernel_image
|
|
File without changes
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
from typing import Union, List, Tuple, Callable
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from skimage.morphology import ball, disk
|
|
7
|
+
|
|
8
|
+
from batchgeneratorsv2.helpers.fft_conv import fft_conv
|
|
9
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
10
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def binary_dilation_torch(input_tensor, structure_element):
    """
    Binary dilation of a 2D or 3D tensor via FFT convolution with the structuring element.

    An output voxel is set where the rounded convolution response is positive.

    Args:
        input_tensor: 2D (X, Y) or 3D (X, Y, Z) binary tensor (bool or 0/1); converted to float internally.
        structure_element: Structuring element tensor with the same number of dimensions as ``input_tensor``.

    Returns:
        bool tensor as produced by ``fft_conv(..., padding='same')`` with the two leading singleton axes
        removed (presumably the input shape).

    Raises:
        ValueError: If ``input_tensor`` is neither 2D nor 3D.
    """
    signal = input_tensor.float()
    if signal.dim() not in (2, 3):
        raise ValueError("Input tensor must be 2D (X, Y) or 3D (X, Y, Z).")

    # 2D and 3D both only need leading (batch=1, channel=1) axes for fft_conv.
    weight = structure_element[None, None].float()
    response = fft_conv(signal[None, None], weight, padding='same')

    # Round first (presumably to suppress FFT round-off), then threshold to a binary map.
    hit = torch.round(response, decimals=0) > 0
    return hit.squeeze(0).squeeze(0).bool()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def binary_erosion_torch(input_tensor, structure_element):
    """
    Binary erosion computed as ``~dilate(~input)`` using ``binary_dilation_torch``.

    ``~`` is logical NOT only for bool tensors, so ``input_tensor`` is expected to be bool.
    """
    complement = ~input_tensor
    dilated_complement = binary_dilation_torch(complement, structure_element)
    return ~dilated_complement
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def binary_opening_torch(input_tensor, structure_element):
    """Binary opening: erosion followed by dilation with the same structuring element."""
    eroded = binary_erosion_torch(input_tensor, structure_element)
    return binary_dilation_torch(eroded, structure_element)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def binary_closing_torch(input_tensor, structure_element):
    """Binary closing: dilation followed by erosion with the same structuring element."""
    dilated = binary_dilation_torch(input_tensor, structure_element)
    return binary_erosion_torch(dilated, structure_element)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ApplyRandomBinaryOperatorTransform(ImageOnlyTransform):
    """
    Randomly applies a binary morphological operator (dilation/erosion/closing/opening by default) to
    selected binary channels of ``image``.

    Voxels newly added to one channel are cleared in the other channels listed in ``channel_idx``.
    """

    def __init__(self,
                 channel_idx: Union[int, List[int], Tuple[int, ...]],
                 any_of_these: Tuple[Callable, ...] = (binary_dilation_torch, binary_erosion_torch, binary_closing_torch, binary_opening_torch),
                 strel_size: RandomScalar = (1, 10),
                 p_per_label: float = 1):
        """
        We use fft conv. Slower for small kernels but much faster on larger kernels.

        Args:
            channel_idx: Channel index or indices the operators may act on. Stored as a private copy, so a
                list passed in by the caller is never mutated.
            any_of_these: Candidate operators; one is picked at random per selected channel and called as
                ``op(bool_channel, structuring_element_tensor)``.
            strel_size: Radius for skimage ``disk`` (2D channels) / ``ball`` (3D channels), sampled with
                ``sample_scalar(strel_size, image=..., channel=...)`` per selected channel.
            p_per_label: Probability that each channel in ``channel_idx`` is processed.
        """
        super().__init__()
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        # Always copy: the caller's list must not be aliased.
        self.channel_idx = list(channel_idx)
        self.any_of_these = any_of_these
        self.strel_size = strel_size
        self.p_per_label = p_per_label

    def get_parameters(self, **data_dict) -> dict:
        # Channels are processed in random order; order matters because _apply_to_image works in place.
        # Shuffle a local copy so neither instance state nor the caller's list is modified.
        channel_order = list(self.channel_idx)
        np.random.shuffle(channel_order)
        selected = torch.rand(len(channel_order)) < self.p_per_label
        apply_to_channels = [c for c, keep in zip(channel_order, selected) if keep]
        operators = [np.random.choice(self.any_of_these) for _ in apply_to_channels]
        strel_size = [sample_scalar(self.strel_size, image=data_dict['image'], channel=a) for a in apply_to_channels]
        return {
            'apply_to_channels': apply_to_channels,
            'operators': operators,
            'strel_size': strel_size,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        for a, o, s in zip(params['apply_to_channels'], params['operators'], params['strel_size']):
            # this is a binary map so bool is fine
            workon = img[a]
            orig_dtype = workon.dtype
            workon = workon.to(bool)
            strel = disk(s, dtype=bool) if workon.ndim == 2 else ball(s, dtype=bool)
            # Put the structuring element on the image's device so the convolution does not mix devices.
            result = o(workon, torch.from_numpy(strel).to(workon.device))
            other_ch = [i for i in self.channel_idx if i != a]
            if len(other_ch) > 0:
                # Voxels gained by channel `a` are removed from all other managed channels.
                was_added_mask = result & (~workon)
                for oc in other_ch:
                    img[oc][was_added_mask] = 0
            img[a] = result.to(orig_dtype)
        return img
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# class ApplyRandomBinaryOperatorTransformNpy(ImageOnlyTransform):
|
|
113
|
+
# def __init__(self,
|
|
114
|
+
# channel_idx: Union[int, List[int], Tuple[int, ...]],
|
|
115
|
+
# any_of_these: Tuple[Callable, ...] = (binary_dilation, binary_erosion, binary_closing, binary_opening),
|
|
116
|
+
# strel_size: ScalarType = (1, 10),
|
|
117
|
+
# p_per_label: float = 1):
|
|
118
|
+
# """
|
|
119
|
+
# We stick to the nnunet implementation and dont optimize. This is a TODO for the future
|
|
120
|
+
#
|
|
121
|
+
# Args:
|
|
122
|
+
# channel_idx:
|
|
123
|
+
# any_of_these:
|
|
124
|
+
# strel_size:
|
|
125
|
+
# p_per_label:
|
|
126
|
+
# """
|
|
127
|
+
# super().__init__()
|
|
128
|
+
# if not isinstance(channel_idx, (list, tuple)):
|
|
129
|
+
# channel_idx = [channel_idx]
|
|
130
|
+
# if isinstance(channel_idx, tuple):
|
|
131
|
+
# channel_idx = list(channel_idx)
|
|
132
|
+
# self.channel_idx = channel_idx
|
|
133
|
+
# self.any_of_these = any_of_these
|
|
134
|
+
# self.strel_size = strel_size
|
|
135
|
+
# self.p_per_label = p_per_label
|
|
136
|
+
#
|
|
137
|
+
# def get_parameters(self, **data_dict) -> dict:
|
|
138
|
+
# # this needs to be applied in random order to the channels
|
|
139
|
+
# np.random.shuffle(self.channel_idx)
|
|
140
|
+
# apply_to_channels = [self.channel_idx[i] for i, j in enumerate(torch.rand(len(self.channel_idx)) < self.p_per_label) if j]
|
|
141
|
+
# operators = [np.random.choice(self.any_of_these) for _ in apply_to_channels]
|
|
142
|
+
# strel_size = [sample_scalar(self.strel_size, image=data_dict['image'], channel=a) for a in apply_to_channels]
|
|
143
|
+
# return {
|
|
144
|
+
# 'apply_to_channels': apply_to_channels,
|
|
145
|
+
# 'operators': operators,
|
|
146
|
+
# 'strel_size': strel_size,
|
|
147
|
+
# }
|
|
148
|
+
#
|
|
149
|
+
# def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
|
|
150
|
+
# for a, o, s in zip(params['apply_to_channels'], params['operators'], params['strel_size']):
|
|
151
|
+
# # this is a binary map so bool is fine
|
|
152
|
+
# workon = img[a].numpy()
|
|
153
|
+
# orig_dtype = workon.dtype
|
|
154
|
+
# workon = workon.astype(bool)
|
|
155
|
+
# if workon.ndim == 2:
|
|
156
|
+
# strel = disk(s, dtype=bool)
|
|
157
|
+
# else:
|
|
158
|
+
# strel = ball(s, dtype=bool)
|
|
159
|
+
# result = o(workon, strel)
|
|
160
|
+
# other_ch = [i for i in self.channel_idx if i != a]
|
|
161
|
+
# if len(other_ch) > 0:
|
|
162
|
+
# was_added_mask = result & (~workon)
|
|
163
|
+
# for oc in other_ch:
|
|
164
|
+
# img[oc][was_added_mask] = 0
|
|
165
|
+
# img[a] = torch.from_numpy(result)
|
|
166
|
+
# return img
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
if __name__ == '__main__':
    # Single-threaded timing of ApplyRandomBinaryOperatorTransform on a toy two-channel one-hot volume.
    torch.set_num_threads(1)

    transform = ApplyRandomBinaryOperatorTransform(channel_idx=(0, 1), strel_size=1)

    durations = []
    for _ in range(10):
        # Alternative, noisier input:
        # img = (torch.rand((1, 128, 128, 128)) < 0.5).to(torch.int16)
        # img = torch.cat((img, 1 - img))
        foreground = torch.zeros((1, 50, 50, 50))
        foreground[0, :10, :10, :10] = 1  # 10x10x10 cube in one corner
        batch = {'image': torch.cat((foreground, 1 - foreground))}  # channel 0: cube, channel 1: rest
        start = time()
        out = transform(**batch)
        durations.append(time() - start)
    print('torch', np.mean(durations))

    # from batchviewer import view_batch
    # view_batch(torch.cat((foreground, 1 - foreground)), out['image'], torch.cat((foreground, 1 - foreground)) - out['image'])
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
from typing import Union, List, Tuple
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from acvl_utils.morphology.morphology_helper import label_with_component_sizes
|
|
6
|
+
|
|
7
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RemoveRandomConnectedComponentFromOneHotEncodingTransform(ImageOnlyTransform):
    """
    Removes one randomly chosen connected component from selected one-hot channels of ``image``.

    Args:
        channel_idx: Channel index or indices to process. Stored as a private copy, so a list passed in by the
            caller is never mutated.
        fill_with_other_class_p: Probability that the removed voxels are set to 1 in a randomly chosen other
            channel from ``channel_idx``.
        dont_do_if_covers_more_than_x_percent: A FRACTION (0.25 == 25 %), despite the name. Components whose
            voxel count is at least this fraction of the channel's total voxel count are never removed.
        p_per_label: Probability that each channel in ``channel_idx`` is processed.
    """

    def __init__(self,
                 channel_idx: Union[int, List[int], Tuple[int, ...]],
                 fill_with_other_class_p: float = 0.25,
                 dont_do_if_covers_more_than_x_percent: float = 0.25,
                 p_per_label: float = 1
                 ):
        super().__init__()
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        # Always copy: the caller's list must not be aliased.
        self.channel_idx = list(channel_idx)
        self.fill_with_other_class_p = fill_with_other_class_p
        self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent
        self.p_per_label = p_per_label

    def get_parameters(self, **data_dict) -> dict:
        # Channels are processed in random order; shuffle a local copy so instance state is left untouched.
        channel_order = list(self.channel_idx)
        np.random.shuffle(channel_order)
        selected = torch.rand(len(channel_order)) < self.p_per_label
        apply_to_channels = [c for c, keep in zip(channel_order, selected) if keep]

        # self.fill_with_other_class_p cannot be resolved here because we don't know how many components there are
        return {
            'apply_to_channels': apply_to_channels,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        for a in params['apply_to_channels']:
            # .cpu() so numpy conversion also works when img is not on the CPU.
            workon = img[a].to(bool).cpu().numpy()
            if not np.any(workon):
                continue
            num_voxels = np.prod(workon.shape, dtype=np.uint64)
            lab, component_sizes = label_with_component_sizes(workon)
            if len(component_sizes) == 0:
                continue
            max_component_size = num_voxels * self.dont_do_if_covers_more_than_x_percent
            valid_component_ids = [i for i, j in component_sizes.items() if j < max_component_size]
            if len(valid_component_ids) == 0:
                continue

            random_component = np.random.choice(valid_component_ids)
            # Build the component mask once, as a tensor on the same device as img.
            component_mask = torch.from_numpy(lab == random_component).to(img.device)
            img[a][component_mask] = 0
            if np.random.uniform() < self.fill_with_other_class_p:
                other_ch = [i for i in self.channel_idx if i != a]
                if len(other_ch) > 0:
                    other_class = np.random.choice(other_ch)
                    img[other_class][component_mask] = 1
        return img
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
if __name__ == '__main__':
    # Single-threaded timing of the component-removal transform on a toy two-channel one-hot volume.
    torch.set_num_threads(1)

    transform = RemoveRandomConnectedComponentFromOneHotEncodingTransform(
        channel_idx=(0, 1),
        fill_with_other_class_p=1,
        dont_do_if_covers_more_than_x_percent=0.25,
        p_per_label=1
    )

    durations = []
    for _ in range(10):
        # Alternative, noisier input:
        # img = (torch.rand((1, 128, 128, 128)) < 0.5).to(torch.int16)
        # img = torch.cat((img, 1 - img))
        foreground = torch.zeros((1, 64, 64, 64))
        # Three separate 10x10x10 cubes -> three connected components in channel 0.
        foreground[0, :10, :10, :10] = 1
        foreground[0, -10:, -10:, -10:] = 1
        foreground[0, -10:, :10, -10:] = 1
        batch = {'image': torch.cat((foreground, 1 - foreground))}
        start = time()
        out = transform(**batch)
        durations.append(time() - start)
    print('torch', np.mean(durations))

    from batchviewer import view_batch
    view_batch(torch.cat((foreground, 1 - foreground)), out['image'], torch.cat((foreground, 1 - foreground)) - out['image'])
|