batchgeneratorsv2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. batchgeneratorsv2/benchmarks/__init__.py +0 -0
  2. batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  3. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  4. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  5. batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  6. batchgeneratorsv2/dataloading/__init__.py +0 -0
  7. batchgeneratorsv2/helpers/__init__.py +0 -0
  8. batchgeneratorsv2/helpers/fft_conv.py +149 -0
  9. batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
  13. batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  15. batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
  16. batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  17. batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  18. batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  19. batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  20. batchgeneratorsv2/transforms/local/__init__.py +0 -0
  21. batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  22. batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  23. batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  24. batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  25. batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  26. batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
  28. batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  29. batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  30. batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  32. batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  33. batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  34. batchgeneratorsv2/transforms/noise/rician.py +61 -0
  35. batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  36. batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  37. batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  38. batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
  39. batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  40. batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  41. batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
  42. batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  43. batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  44. batchgeneratorsv2/transforms/utils/compose.py +89 -0
  45. batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  46. batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  47. batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  48. batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
  49. batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
  50. batchgeneratorsv2/transforms/utils/random.py +46 -0
  51. batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
  52. batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  53. batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
  54. batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
  55. batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
  56. batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
  57. batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,90 @@
1
+ import torch
2
+ import numpy as np
3
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
6
+
7
+
8
class LocalContrastTransform(ImageOnlyTransform, LocalTransform):
    """
    Changes image contrast locally, inside a randomly placed Gaussian-shaped region.

    For each selected channel a contrast-scaled copy is computed around the kernel-weighted mean
    intensity. That copy is then blended with the untouched channel, using the kernel as the
    per-voxel blending weight.

    Args:
        scale (RandomScalar): Spread of the Gaussian weighting mask (forwarded to LocalTransform).
        loc (RandomScalar): Relative center position of the Gaussian (forwarded to LocalTransform).
        new_contrast (RandomScalar): Multiplicative contrast factor. 1.0 = no change.
        same_for_all_channels (bool): If True, one kernel/contrast pair is shared by all selected
            channels. Otherwise every selected channel gets its own pair.
        p_per_channel (float): Probability with which each channel is selected.
    """

    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 new_contrast: RandomScalar = (0.5, 1.5),
                 same_for_all_channels: bool = True,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.p_per_channel = p_per_channel
        self.same_for_all_channels = same_for_all_channels
        self.new_contrast = new_contrast

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """
        Decide per channel whether to augment and sample the kernel and contrast factor.

        Returns:
            dict: 'kernels' and 'contrasts', two lists with one entry per channel. Entries of
            channels that are not augmented are None.
        """
        num_channels, *spatial_shape = image.shape
        selected = [np.random.rand() < self.p_per_channel for _ in range(num_channels)]

        kernels = [None] * num_channels
        contrasts = [None] * num_channels
        if not any(selected):
            return {'kernels': kernels, 'contrasts': contrasts}

        def draw():
            # kernel first, contrast second: the order of the random draws matters for reproducibility
            weight = self._generate_kernel(spatial_shape).astype(np.float32)
            factor = sample_scalar(self.new_contrast)
            return weight, factor

        if self.same_for_all_channels:
            # one draw, reused by every selected channel
            shared_weight, shared_factor = draw()
            for idx, is_selected in enumerate(selected):
                if is_selected:
                    kernels[idx], contrasts[idx] = shared_weight, shared_factor
        else:
            # independent draw for every selected channel
            for idx, is_selected in enumerate(selected):
                if is_selected:
                    kernels[idx], contrasts[idx] = draw()

        return {'kernels': kernels, 'contrasts': contrasts}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Apply the sampled local contrast change channel by channel.

        The work happens on the NumPy array obtained via img.cpu().numpy(). For a CPU tensor that
        array shares its memory with img, so the input tensor is modified as well.
        """
        data = img.cpu().numpy()

        for idx, (weight, factor) in enumerate(zip(params['kernels'], params['contrasts'])):
            if weight is not None:
                source = data[idx]
                # kernel-weighted mean serves as the pivot of the contrast scaling
                pivot = (source * weight).sum() / (weight.sum() + 1e-8)
                rescaled = (source - pivot) * factor + pivot
                data[idx] = self.run_interpolation(source, rescaled, weight)

        return torch.from_numpy(data).to(img.device, dtype=img.dtype)
75
+
76
+
77
if __name__ == '__main__':
    from batchviewer import view_batch

    # Random single-channel 3D volume, shape (C, D, H, W)
    image = torch.rand(1, 32, 64, 64)

    transform = LocalContrastTransform(scale=(10, 20), new_contrast=(0.3, 2.0), p_per_channel=1.0)

    # Sample parameters once and apply them to a copy so that the original stays available for comparison
    sampled = transform.get_parameters(image=image)
    image_aug = transform._apply_to_image(image.clone(), **sampled)

    view_batch(image, image_aug)
@@ -0,0 +1,104 @@
1
+ import torch
2
+ import numpy as np
3
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
6
+
7
+
8
class LocalGammaTransform(ImageOnlyTransform, LocalTransform):
    """
    Gamma correction whose strength varies over space according to a Gaussian weighting mask.

    Each selected channel is normalized to [0, 1] and gamma-corrected. The corrected version is
    blended with the normalized original using a randomly placed Gaussian kernel as the weight map,
    and the result is finally mapped back to the channel's original intensity range.

    Parameters:
        scale (RandomScalar): Width of the Gaussian (std dev). Large values (e.g. 10-30) are recommended.
        loc (RandomScalar): Gaussian center relative to the image size. E.g. (-1, 2) allows off-canvas kernels.
        gamma (RandomScalar): Gamma exponent that is applied locally.
        same_for_all_channels (bool): If True, one kernel/gamma pair is reused for all selected channels.
            Otherwise a pair is sampled per channel.
        p_per_channel (float): Probability with which each channel is selected independently.
    """

    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 gamma: RandomScalar = (0.5, 1),
                 same_for_all_channels: bool = True,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.p_per_channel = p_per_channel
        self.same_for_all_channels = same_for_all_channels
        self.gamma = gamma

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """
        Decide per channel whether to augment and sample the kernel and gamma exponent.

        Returns:
            dict: 'kernels' and 'gammas', two lists with one entry per channel. Entries of
            channels that are not augmented are None.
        """
        num_channels, *spatial_shape = image.shape
        selected = [np.random.rand() < self.p_per_channel for _ in range(num_channels)]

        kernels = [None] * num_channels
        gammas = [None] * num_channels
        if not any(selected):
            return {'kernels': kernels, 'gammas': gammas}

        def draw():
            # kernel first, gamma second: the order of the random draws matters for reproducibility
            weight = self._generate_kernel(spatial_shape).astype(np.float32)
            exponent = sample_scalar(self.gamma)
            return weight, exponent

        if self.same_for_all_channels:
            # one draw, reused by every selected channel
            shared_weight, shared_exponent = draw()
            for idx, is_selected in enumerate(selected):
                if is_selected:
                    kernels[idx], gammas[idx] = shared_weight, shared_exponent
        else:
            # independent draw for every selected channel
            for idx, is_selected in enumerate(selected):
                if is_selected:
                    kernels[idx], gammas[idx] = draw()

        return {'kernels': kernels, 'gammas': gammas}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Apply the sampled local gamma correction channel by channel.

        The work happens on the NumPy array obtained via img.cpu().numpy(). For a CPU tensor that
        array shares its memory with img, so the input tensor is modified as well.
        """
        data = img.cpu().numpy()

        for idx, (weight, exponent) in enumerate(zip(params['kernels'], params['gammas'])):
            if weight is None:
                continue

            source = data[idx]
            lowest, highest = source.min(), source.max()
            # lower bound keeps the division well defined for constant channels
            span = max(highest - lowest, 1e-8)

            # gamma is applied on intensities mapped to [0, 1]
            unit = (source - lowest) / span
            mixed = self.run_interpolation(unit, np.power(unit, exponent), weight)

            # back to the original intensity range
            data[idx] = mixed * span + lowest

        return torch.from_numpy(data).to(img.device, dtype=img.dtype)
86
+
87
+
88
if __name__ == '__main__':
    from batchviewer import view_batch

    def _gamma_below_or_above_one(*_):
        # coin flip between a gamma in [0.01, 1) and a gamma in [1, 3)
        if np.random.rand() < 0.5:
            return np.random.uniform(0.01, 1)
        return np.random.uniform(1, 3)

    # Random single-channel 3D volume, shape (C, D, H, W)
    image = torch.rand(1, 32, 64, 64)

    transform = LocalGammaTransform(
        scale=(10, 20),
        gamma=_gamma_below_or_above_one,
        same_for_all_channels=False,
        p_per_channel=1.0
    )

    # Sample parameters once and apply them to a copy so that the original stays available for comparison
    sampled = transform.get_parameters(image=image)
    image_gamma = transform._apply_to_image(image.clone(), **sampled)

    view_batch(image, image_gamma)
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import numpy as np
3
+ from batchgeneratorsv2.transforms.local.local_transform import LocalTransform
4
+ from scipy.ndimage import gaussian_filter
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
7
+
8
+
9
class LocalSmoothingTransform(ImageOnlyTransform, LocalTransform):
    """
    Blurs the image locally, inside a randomly placed Gaussian-shaped region.

    Each selected channel is smoothed with scipy's gaussian_filter. The smoothed version is blended
    with the original, using a Gaussian kernel (scaled by the smoothing strength) as the weight map.

    Args:
        scale (RandomScalar): Spread of the Gaussian weighting mask (forwarded to LocalTransform).
        loc (RandomScalar): Relative center position of the Gaussian (forwarded to LocalTransform).
        smoothing_strength (RandomScalar): Factor the weighting mask is multiplied with, i.e. the
            maximum weight of the smoothed image in the blend.
        kernel_size (RandomScalar): Sigma passed to gaussian_filter. It is sampled once per call
            and shared by all channels.
        same_for_all_channels (bool): If True, one kernel/strength pair is shared by all selected
            channels. Otherwise every selected channel gets its own pair.
        p_per_channel (float): Probability with which each channel is selected.
    """

    def __init__(self,
                 scale: RandomScalar,
                 loc: RandomScalar = (-1, 2),
                 smoothing_strength: RandomScalar = (0.5, 1.0),
                 kernel_size: RandomScalar = (0.5, 1.5),
                 same_for_all_channels: bool = True,
                 p_per_channel: float = 1.0):
        ImageOnlyTransform.__init__(self)
        LocalTransform.__init__(self, scale, loc)

        self.p_per_channel = p_per_channel
        self.same_for_all_channels = same_for_all_channels
        self.kernel_size = kernel_size
        self.smoothing_strength = smoothing_strength

    def get_parameters(self, image: torch.Tensor, **kwargs) -> dict:
        """
        Decide per channel whether to augment and sample sigma, kernels and strengths.

        Returns:
            dict: 'kernels' and 'strengths' (lists with one entry per channel, None for channels
            that are not augmented) and 'sigma' (None if no channel was selected).
        """
        num_channels, *spatial_shape = image.shape
        selected = [np.random.rand() < self.p_per_channel for _ in range(num_channels)]

        kernels = [None] * num_channels
        strengths = [None] * num_channels
        if not any(selected):
            return {'kernels': kernels, 'sigma': None, 'strengths': strengths}

        # sigma is drawn before any kernel so that the order of the random draws stays fixed
        sigma = sample_scalar(self.kernel_size)

        def draw():
            weight = self._generate_kernel(spatial_shape).astype(np.float32)
            amount = sample_scalar(self.smoothing_strength)
            return weight, amount

        if self.same_for_all_channels:
            # one draw, reused by every selected channel
            shared_weight, shared_amount = draw()
            for idx, is_selected in enumerate(selected):
                if is_selected:
                    kernels[idx], strengths[idx] = shared_weight, shared_amount
        else:
            # independent draw for every selected channel
            for idx, is_selected in enumerate(selected):
                if is_selected:
                    kernels[idx], strengths[idx] = draw()

        return {'kernels': kernels, 'sigma': sigma, 'strengths': strengths}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Apply the sampled local smoothing channel by channel.

        The work happens on the NumPy array obtained via img.cpu().numpy(). For a CPU tensor that
        array shares its memory with img, so the input tensor is modified as well.
        """
        data = img.cpu().numpy()
        sigma = params['sigma']

        for idx, (weight, amount) in enumerate(zip(params['kernels'], params['strengths'])):
            if weight is not None:
                # scaling the mask limits how much of the blurred image can enter the blend
                scaled_weight = weight * amount
                blurred = gaussian_filter(data[idx], sigma=sigma)
                data[idx] = self.run_interpolation(data[idx], blurred, scaled_weight)

        return torch.from_numpy(data).to(img.device, dtype=img.dtype)
82
+
83
+
84
if __name__ == '__main__':
    from batchviewer import view_batch

    # Random single-channel 3D volume, shape (C, D, H, W)
    image = torch.rand(1, 32, 64, 64)

    transform = LocalSmoothingTransform(loc=(0, 1), scale=(10, 20), kernel_size=(3, 10), p_per_channel=1.0)

    # Sample parameters once and apply them to a copy so that the original stays available for comparison
    sampled = transform.get_parameters(image=image)
    image_aug = transform._apply_to_image(image.clone(), **sampled)

    view_batch(image, image_aug)
98
+
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+ import scipy.stats as st
3
+ from abc import ABC
4
+ from typing import Tuple
5
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
6
+
7
+
8
class LocalTransform(ABC):
    """
    Base class for local augmentations that need a spatial Gaussian weighting mask ("kernel").

    Parameters:
        scale (RandomScalar): Source for the standard deviation of the Gaussian. It is resolved per axis
            through sample_scalar.
        loc (RandomScalar): Source for the Gaussian center, relative to the axis length (the sampled value
            is multiplied with the axis length). It is resolved per axis through sample_scalar.
    """

    def __init__(self, scale: RandomScalar, loc: RandomScalar = (-1, 2)):
        self.loc = loc
        self.scale = scale

    def _generate_kernel(self, img_shp: Tuple[int, ...]) -> np.ndarray:
        """
        Builds one separable Gaussian kernel that covers an image of the given spatial shape.

        For every axis the mass of a normal distribution inside each unit-wide bin is computed as the
        difference of the CDF at the bin edges. The N-D kernel is the outer product of these per-axis
        profiles. It is shifted and scaled to [0, 1] (an all-constant kernel ends up all zero).

        Works for any number of spatial axes >= 1.

        Parameters:
            img_shp (Tuple[int, ...]): Spatial shape (e.g., (X, Y[, Z]))

        Returns:
            np.ndarray: float32 kernel of shape img_shp

        Raises:
            ValueError: if img_shp is empty
        """
        if len(img_shp) == 0:
            raise ValueError("img_shp must contain at least one spatial dimension.")

        profiles = []
        for d, size in enumerate(img_shp):
            loc_val = sample_scalar(self.loc, img_shp, d)
            scale_val = sample_scalar(self.scale, img_shp, d)
            # size + 1 bin edges at -0.5, 0.5, ..., size - 0.5 -> np.diff yields one value per voxel
            edges = np.arange(-0.5, size + 0.5, dtype=np.float32)
            cdf = st.norm.cdf(edges, loc=loc_val * size, scale=scale_val)
            profiles.append(np.diff(cdf).astype(np.float32))

        # outer product over all axes; for 2D/3D this equals the previous matmul-based construction
        kernel = profiles[0]
        for profile in profiles[1:]:
            kernel = np.multiply.outer(kernel, profile)

        kernel -= kernel.min()
        kernel_max = kernel.max()
        if kernel_max > 0:
            kernel /= kernel_max
        return kernel

    def _generate_multiple_kernel_image(self, img_shp: Tuple[int, ...], num_kernels: int) -> np.ndarray:
        """
        Places multiple additive Gaussians in the image and normalizes the sum to [0, 1].

        Parameters:
            img_shp (Tuple[int, ...]): Spatial shape (e.g., (X, Y[, Z]))
            num_kernels (int): Number of kernels to generate and sum

        Returns:
            np.ndarray: Combined kernel image with values in [0, 1] (all zero if the sum is constant,
            e.g. for num_kernels == 0)
        """
        kernel_image = np.zeros(img_shp, dtype=np.float32)
        for _ in range(num_kernels):
            kernel_image += self._generate_kernel(img_shp)

        kernel_image -= kernel_image.min()
        kernel_max = kernel_image.max()
        if kernel_max > 0:
            kernel_image /= kernel_max
        return kernel_image

    @staticmethod
    def invert_kernel(kernel_image: np.ndarray) -> np.ndarray:
        """
        Inverts a normalized kernel: 1 - kernel

        Assumes input is in [0, 1].

        Parameters:
            kernel_image (np.ndarray): Input kernel in [0, 1]

        Returns:
            np.ndarray: Inverted kernel in [0, 1]
        """
        return 1.0 - kernel_image

    @staticmethod
    def run_interpolation(original_image: np.ndarray,
                          modified_image: np.ndarray,
                          kernel_image: np.ndarray) -> np.ndarray:
        """
        Blends original and modified images using the given kernel as a per-pixel weight map.

        Parameters:
            original_image (np.ndarray): Unmodified input image
            modified_image (np.ndarray): Modified version (e.g., gamma-corrected)
            kernel_image (np.ndarray): Kernel in [0, 1], where 0 = keep original, 1 = keep modified

        Returns:
            np.ndarray: Blended result
        """
        return original_image * (1.0 - kernel_image) + modified_image * kernel_image
File without changes
@@ -0,0 +1,190 @@
1
+ from time import time
2
+ from typing import Union, List, Tuple, Callable
3
+
4
+ import numpy as np
5
+ import torch
6
+ from skimage.morphology import ball, disk
7
+
8
+ from batchgeneratorsv2.helpers.fft_conv import fft_conv
9
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
10
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
11
+
12
+
13
def binary_dilation_torch(input_tensor, structure_element):
    """
    Binary dilation of a 2D or 3D mask, computed as a convolution with the structuring element.

    Args:
        input_tensor: 2D (X, Y) or 3D (X, Y, Z) tensor; it is converted to float before convolving.
        structure_element: structuring element tensor; it is converted to float as well.

    Returns:
        Bool tensor with the batch and channel axes removed again. A position is True where the
        rounded convolution response is larger than 0.

    Raises:
        ValueError: if input_tensor is neither 2D nor 3D.
    """
    as_float = input_tensor.float()

    if as_float.dim() not in (2, 3):
        raise ValueError("Input tensor must be 2D (X, Y) or 3D (X, Y, Z).")

    # both operands get a leading batch and channel axis for the convolution
    strel = structure_element[None, None].float()
    batched = as_float[None, None]

    # fft_conv stands in for F.conv2d / F.conv3d with padding='same'. The response is rounded before
    # thresholding, presumably to suppress numerical noise of the FFT (TODO confirm).
    response = fft_conv(batched, strel, padding='same')
    hits = torch.round(response, decimals=0) > 0

    # drop batch and channel axis again
    return hits.squeeze(0).squeeze(0).bool()
41
+
42
+
43
def binary_erosion_torch(input_tensor, structure_element):
    """Binary erosion via duality: dilate the inverted mask, then invert the result."""
    inverted = ~input_tensor
    dilated_inverse = binary_dilation_torch(inverted, structure_element)
    return ~dilated_inverse
45
+
46
+
47
def binary_opening_torch(input_tensor, structure_element):
    """Binary opening: erosion followed by dilation with the same structuring element."""
    eroded = binary_erosion_torch(input_tensor, structure_element)
    return binary_dilation_torch(eroded, structure_element)
49
+
50
+
51
def binary_closing_torch(input_tensor, structure_element):
    """Binary closing: dilation followed by erosion with the same structuring element."""
    dilated = binary_dilation_torch(input_tensor, structure_element)
    return binary_erosion_torch(dilated, structure_element)
53
+
54
+
55
class ApplyRandomBinaryOperatorTransform(ImageOnlyTransform):
    def __init__(self,
                 channel_idx: Union[int, List[int], Tuple[int, ...]],
                 any_of_these: Tuple[Callable, ...] = (binary_dilation_torch, binary_erosion_torch,
                                                       binary_closing_torch, binary_opening_torch),
                 strel_size: RandomScalar = (1, 10),
                 p_per_label: float = 1):
        """
        Applies a randomly chosen binary morphological operator to selected binary channels.

        The default operators use fft conv. Slower for small kernels but a lot faster on larger kernels.

        Args:
            channel_idx: Index or indices of the channels that are treated as binary maps.
            any_of_these: Candidate operators. Each is called as operator(bool_tensor, structure_element_tensor).
            strel_size: Source for the structuring element size. The sampled value is passed to
                skimage's disk (2D) or ball (3D).
            p_per_label: Probability with which each listed channel is processed.
        """
        super().__init__()
        # Always store a private copy: get_parameters shuffles the channel order, which must never
        # reorder a list that is owned by the caller.
        if isinstance(channel_idx, (list, tuple)):
            channel_idx = list(channel_idx)
        else:
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx
        self.any_of_these = any_of_these
        self.strel_size = strel_size
        self.p_per_label = p_per_label

    def get_parameters(self, **data_dict) -> dict:
        """
        Samples which channels are processed (in random order), with which operator and which size.

        Returns:
            dict: 'apply_to_channels', 'operators' and 'strel_size', three lists of equal length.
        """
        # this needs to be applied in random order to the channels. A local copy is shuffled so
        # that the instance state stays untouched between calls.
        order = list(self.channel_idx)
        np.random.shuffle(order)
        keep = (torch.rand(len(order)) < self.p_per_label).tolist()
        apply_to_channels = [ch for ch, k in zip(order, keep) if k]
        operators = [np.random.choice(self.any_of_these) for _ in apply_to_channels]
        strel_size = [sample_scalar(self.strel_size, image=data_dict['image'], channel=a) for a in apply_to_channels]
        return {
            'apply_to_channels': apply_to_channels,
            'operators': operators,
            'strel_size': strel_size,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Runs the sampled operators channel by channel. img is modified in place and returned.

        Voxels that an operator newly adds to a channel are set to 0 in all other channels listed
        in channel_idx.
        """
        for a, o, s in zip(params['apply_to_channels'], params['operators'], params['strel_size']):
            # this is a binary map so bool is fine
            workon = img[a]
            orig_dtype = workon.dtype
            workon = workon.to(bool)
            strel = disk(s, dtype=bool) if workon.ndim == 2 else ball(s, dtype=bool)
            result = o(workon, torch.from_numpy(strel))
            other_ch = [i for i in self.channel_idx if i != a]
            if other_ch:
                # voxels that were not set before but are set after the operator
                was_added_mask = result & (~workon)
                for oc in other_ch:
                    img[oc][was_added_mask] = 0
            img[a] = result.to(orig_dtype)
        return img
110
+
111
+
112
+ # class ApplyRandomBinaryOperatorTransformNpy(ImageOnlyTransform):
113
+ # def __init__(self,
114
+ # channel_idx: Union[int, List[int], Tuple[int, ...]],
115
+ # any_of_these: Tuple[Callable, ...] = (binary_dilation, binary_erosion, binary_closing, binary_opening),
116
+ # strel_size: ScalarType = (1, 10),
117
+ # p_per_label: float = 1):
118
+ # """
119
+ # We stick to the nnunet implementation and dont optimize. This is a TODO for the future
120
+ #
121
+ # Args:
122
+ # channel_idx:
123
+ # any_of_these:
124
+ # strel_size:
125
+ # p_per_label:
126
+ # """
127
+ # super().__init__()
128
+ # if not isinstance(channel_idx, (list, tuple)):
129
+ # channel_idx = [channel_idx]
130
+ # if isinstance(channel_idx, tuple):
131
+ # channel_idx = list(channel_idx)
132
+ # self.channel_idx = channel_idx
133
+ # self.any_of_these = any_of_these
134
+ # self.strel_size = strel_size
135
+ # self.p_per_label = p_per_label
136
+ #
137
+ # def get_parameters(self, **data_dict) -> dict:
138
+ # # this needs to be applied in random order to the channels
139
+ # np.random.shuffle(self.channel_idx)
140
+ # apply_to_channels = [self.channel_idx[i] for i, j in enumerate(torch.rand(len(self.channel_idx)) < self.p_per_label) if j]
141
+ # operators = [np.random.choice(self.any_of_these) for _ in apply_to_channels]
142
+ # strel_size = [sample_scalar(self.strel_size, image=data_dict['image'], channel=a) for a in apply_to_channels]
143
+ # return {
144
+ # 'apply_to_channels': apply_to_channels,
145
+ # 'operators': operators,
146
+ # 'strel_size': strel_size,
147
+ # }
148
+ #
149
+ # def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
150
+ # for a, o, s in zip(params['apply_to_channels'], params['operators'], params['strel_size']):
151
+ # # this is a binary map so bool is fine
152
+ # workon = img[a].numpy()
153
+ # orig_dtype = workon.dtype
154
+ # workon = workon.astype(bool)
155
+ # if workon.ndim == 2:
156
+ # strel = disk(s, dtype=bool)
157
+ # else:
158
+ # strel = ball(s, dtype=bool)
159
+ # result = o(workon, strel)
160
+ # other_ch = [i for i in self.channel_idx if i != a]
161
+ # if len(other_ch) > 0:
162
+ # was_added_mask = result & (~workon)
163
+ # for oc in other_ch:
164
+ # img[oc][was_added_mask] = 0
165
+ # img[a] = torch.from_numpy(result)
166
+ # return img
167
+
168
+
169
if __name__ == '__main__':
    # small timing benchmark of the transform on a synthetic two-channel one-hot volume
    torch.set_num_threads(1)

    tr = ApplyRandomBinaryOperatorTransform(channel_idx=(0, 1), strel_size=1)

    times_torch = []
    for _ in range(10):
        # alternative input: random one-hot volume
        # img = (torch.rand((1, 128, 128, 128)) < 0.5).to(torch.int16)
        # img = torch.cat((img, 1 - img))
        img = torch.zeros((1, 50, 50, 50))
        img[0, :10, :10, :10] = 1
        data_dict = {'image': torch.cat((img, 1 - img))}
        started = time()
        out = tr(**data_dict)
        times_torch.append(time() - started)
    print('torch', np.mean(times_torch))

    # optional visual check:
    # from batchviewer import view_batch
    # view_batch(torch.cat((img, 1 - img)), out['image'], torch.cat((img, 1 - img)) - out['image'])
@@ -0,0 +1,86 @@
1
+ from time import time
2
+ from typing import Union, List, Tuple
3
+ import numpy as np
4
+ import torch
5
+ from acvl_utils.morphology.morphology_helper import label_with_component_sizes
6
+
7
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
8
+
9
+
10
class RemoveRandomConnectedComponentFromOneHotEncodingTransform(ImageOnlyTransform):
    def __init__(self,
                 channel_idx: Union[int, List[int], Tuple[int, ...]],
                 fill_with_other_class_p: float = 0.25,
                 dont_do_if_covers_more_than_x_percent: float = 0.25,
                 p_per_label: float = 1
                 ):
        """
        Removes one random connected component from selected binary channels.

        Args:
            channel_idx: Index or indices of the channels that are treated as binary maps.
            fill_with_other_class_p: Probability that the removed component is set to 1 in a randomly
                chosen other channel from channel_idx.
            dont_do_if_covers_more_than_x_percent: Only components smaller than this fraction of the
                number of voxels of the channel are candidates for removal.
            p_per_label: Probability with which each listed channel is processed.
        """
        super().__init__()
        # Always store a private copy: get_parameters shuffles the channel order, which must never
        # reorder a list that is owned by the caller.
        if isinstance(channel_idx, (list, tuple)):
            channel_idx = list(channel_idx)
        else:
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx
        self.fill_with_other_class_p = fill_with_other_class_p
        self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent
        self.p_per_label = p_per_label

    def get_parameters(self, **data_dict) -> dict:
        """
        Samples which channels are processed, in random order.

        Returns:
            dict: 'apply_to_channels', the list of selected channel indices.
        """
        # this needs to be applied in random order to the channels. A local copy is shuffled so
        # that the instance state stays untouched between calls.
        order = list(self.channel_idx)
        np.random.shuffle(order)
        keep = (torch.rand(len(order)) < self.p_per_label).tolist()
        apply_to_channels = [ch for ch, k in zip(order, keep) if k]

        # self.fill_with_other_class_p cannot be resolved here because we don't know how many components there are
        return {
            'apply_to_channels': apply_to_channels,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """
        Removes one random, sufficiently small connected component per selected channel.

        img is modified in place and returned. Channels without any foreground or without a valid
        component are left untouched.
        """
        for a in params['apply_to_channels']:
            workon = img[a].to(bool).numpy()
            if not np.any(workon):
                continue
            num_voxels = np.prod(workon.shape, dtype=np.uint64)
            max_component_size = num_voxels * self.dont_do_if_covers_more_than_x_percent
            lab, component_sizes = label_with_component_sizes(workon.astype(bool))
            valid_component_ids = [i for i, j in component_sizes.items() if j < max_component_size]
            if not valid_component_ids:
                continue

            random_component = np.random.choice(valid_component_ids)
            # computed once, used for the removal and for the optional fill
            component_mask = lab == random_component
            img[a][component_mask] = 0
            if np.random.uniform() < self.fill_with_other_class_p:
                other_ch = [i for i in self.channel_idx if i != a]
                if other_ch:
                    other_class = np.random.choice(other_ch)
                    img[other_class][component_mask] = 1
        return img
59
+
60
+
61
if __name__ == '__main__':
    # small timing benchmark of the transform on a synthetic two-channel one-hot volume
    torch.set_num_threads(1)

    tr = RemoveRandomConnectedComponentFromOneHotEncodingTransform(
        channel_idx=(0, 1),
        fill_with_other_class_p=1,
        dont_do_if_covers_more_than_x_percent=0.25,
        p_per_label=1
    )

    times_torch = []
    for _ in range(10):
        # alternative input: random one-hot volume
        # img = (torch.rand((1, 128, 128, 128)) < 0.5).to(torch.int16)
        # img = torch.cat((img, 1 - img))
        img = torch.zeros((1, 64, 64, 64))
        # three separate cubes in the corners of channel 0
        img[0, :10, :10, :10] = 1
        img[0, -10:, -10:, -10:] = 1
        img[0, -10:, :10, -10:] = 1
        data_dict = {'image': torch.cat((img, 1 - img))}
        started = time()
        out = tr(**data_dict)
        times_torch.append(time() - started)
    print('torch', np.mean(times_torch))

    from batchviewer import view_batch
    view_batch(torch.cat((img, 1 - img)), out['image'], torch.cat((img, 1 - img)) - out['image'])