batchgeneratorsv2 0.3.0__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/PKG-INFO +1 -2
  2. batchgeneratorsv2-0.3.3/batchgeneratorsv2/helpers/fft_conv.py +149 -0
  3. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/intensity/brightness.py +33 -16
  4. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/intensity/contrast.py +53 -27
  5. batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  6. batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  7. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/intensity/random_clip.py +5 -6
  8. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +2 -5
  9. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/noise/blank_rectangle.py +2 -2
  10. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +3 -3
  11. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/noise/sharpen.py +2 -2
  12. batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  13. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/spatial/low_resolution.py +1 -1
  14. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/spatial/spatial.py +292 -264
  15. batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/utils/compose.py +89 -0
  16. batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  17. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +4 -2
  18. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/utils/pseudo2d.py +3 -5
  19. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/utils/remove_label.py +4 -1
  20. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/utils/seg_to_regions.py +1 -1
  21. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2.egg-info/PKG-INFO +1 -2
  22. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2.egg-info/SOURCES.txt +3 -1
  23. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
  24. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2.egg-info/requires.txt +0 -1
  25. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
  26. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/pyproject.toml +2 -3
  27. batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/intensity/gamma.py +0 -88
  28. batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -80
  29. batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  30. batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/utils/compose.py +0 -14
  31. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/LICENSE +0 -0
  32. {batchgeneratorsv2-0.3.0/batchgeneratorsv2 → batchgeneratorsv2-0.3.3/batchgeneratorsv2/benchmarks}/__init__.py +0 -0
  33. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/benchmarks → batchgeneratorsv2-0.3.3/batchgeneratorsv2/benchmarks/bg_comparison}/__init__.py +0 -0
  34. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +0 -0
  35. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +0 -0
  36. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
  37. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/benchmarks/bg_comparison → batchgeneratorsv2-0.3.3/batchgeneratorsv2/dataloading}/__init__.py +0 -0
  38. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/dataloading → batchgeneratorsv2-0.3.3/batchgeneratorsv2/helpers}/__init__.py +0 -0
  39. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
  40. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/helpers → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms}/__init__.py +0 -0
  41. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/base}/__init__.py +0 -0
  42. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/base/basic_transform.py +0 -0
  43. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/base → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/intensity}/__init__.py +0 -0
  44. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/intensity/inversion.py +0 -0
  45. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/intensity → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/local}/__init__.py +0 -0
  46. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/local/brightness_gradient.py +0 -0
  47. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/local/local_contrast.py +0 -0
  48. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/local/local_gamma.py +0 -0
  49. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/local/local_smoothing.py +0 -0
  50. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/local/local_transform.py +0 -0
  51. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/local → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/nnunet}/__init__.py +0 -0
  52. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
  53. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
  54. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/nnunet → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/noise}/__init__.py +0 -0
  55. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/noise/median_filter.py +0 -0
  56. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/noise/rician.py +0 -0
  57. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/noise → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/spatial}/__init__.py +0 -0
  58. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
  59. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/spatial/rot90.py +0 -0
  60. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/spatial/transpose.py +0 -0
  61. {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/spatial → batchgeneratorsv2-0.3.3/batchgeneratorsv2/transforms/utils}/__init__.py +0 -0
  62. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
  63. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
  64. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/batchgeneratorsv2/transforms/utils/random.py +0 -0
  65. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/readme.md +0 -0
  66. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/setup.cfg +0 -0
  67. {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.3}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: batchgeneratorsv2
3
- Version: 0.3.0
3
+ Version: 0.3.3
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -220,7 +220,6 @@ Description-Content-Type: text/markdown
220
220
  License-File: LICENSE
221
221
  Requires-Dist: torch>=2.0.0
222
222
  Requires-Dist: numpy
223
- Requires-Dist: fft-conv-pytorch
224
223
  Requires-Dist: batchgenerators>=0.25
225
224
  Dynamic: license-file
226
225
 
@@ -0,0 +1,149 @@
1
+ from typing import Iterable, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as f
5
+ from torch import Tensor, nn
6
+ from torch.fft import irfftn, rfftn
7
+ from math import ceil, floor
8
+
9
+
10
+ # Taken from here: https://github.com/vcasellesb/fft-conv-pytorch/tree/non-tup-slice-fix. THANK YOU!
11
+ # Original codebase is here: https://github.com/fkodom/fft-conv-pytorch -> unfortunately was not updated for pytorch 2.9.
12
+
13
+
14
def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Contract the channel axes of two complex spectra, optionally grouped.

    ``a`` is viewed as (batch, groups, in_per_group, *spatial) and ``b`` as
    (groups, out_per_group, in_per_group, *spatial). For every spatial
    frequency and group, the in_per_group axis is contracted, which yields a
    tensor of shape (batch, groups * out_per_group, *spatial). The result is
    always complex64.
    """
    # Split the channel axes into (groups, channels_per_group).
    lhs = a.view(a.size(0), groups, -1, *a.shape[2:])
    rhs = b.view(groups, -1, *b.shape[1:])

    # Put the contracted (input-channel) axis last so that '@' sums over it:
    # lhs -> (..., 1, in_per_group), rhs -> (..., in_per_group, out_per_group).
    lhs = torch.movedim(lhs, 2, lhs.dim() - 1).unsqueeze(-2)
    rhs_last = rhs.dim() - 1
    rhs = torch.movedim(rhs, (1, 2), (rhs_last, rhs_last - 1))

    lhs_re, lhs_im = lhs.real, lhs.imag
    rhs_re, rhs_im = rhs.real, rhs.imag
    # (x + iy)(u + iv) = (xu - yv) + i(yu + xv), done with real matmuls.
    prod_re = lhs_re @ rhs_re - lhs_im @ rhs_im
    prod_im = lhs_im @ rhs_re + lhs_re @ rhs_im

    # Move out_per_group back next to the group axis and drop the singleton row axis.
    prod_re = torch.movedim(prod_re, prod_re.dim() - 1, 2).squeeze(-1)
    prod_im = torch.movedim(prod_im, prod_im.dim() - 1, 2).squeeze(-1)

    # Assemble into a freshly allocated (contiguous) complex64 tensor so .view works below.
    result = torch.zeros(prod_re.shape, dtype=torch.complex64, device=lhs.device)
    result.real, result.imag = prod_re, prod_im

    # Merge (groups, out_per_group) back into a single channel axis.
    return result.view(result.size(0), -1, *result.shape[3:])
35
+
36
+
37
def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
    """Normalize a scalar-or-sequence argument to a tuple of exactly ``n`` items.

    Handy for convolution arguments (padding, stride, dilation) where callers
    may pass a single int that should apply to every spatial dimension.

    Args:
        val: A single value (repeated ``n`` times) or an iterable of ``n`` values.
        n: Required tuple length.

    Returns:
        A tuple of length ``n``.

    Raises:
        ValueError: If ``val`` is iterable but does not have exactly ``n`` items.
    """
    if not isinstance(val, Iterable):
        return (val,) * n

    as_tuple = tuple(val)
    if len(as_tuple) != n:
        raise ValueError(f"Cannot cast tuple of length {len(as_tuple)} to length {n}.")
    return as_tuple
56
+
57
+
58
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    padding_mode: str = "constant",
    stride: Union[int, Iterable[int]] = 1,
    dilation: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Perform an N-d convolution using a fast Fourier transform.

    This is very fast for large kernel sizes. Like PyTorch's direct convolution
    it computes a cross-correlation (the kernel spectrum is conjugated). It
    optionally adds a bias afterwards to mimic ``torch.nn.functional.convNd``.

    Args:
        signal: Input tensor of shape (batch, in_channels, *spatial).
        kernel: Kernel of shape (out_channels, in_channels // groups, *kernel_spatial).
        bias: Optional tensor with one entry per output channel, added to the output.
        padding: Either an int or a per-dimension iterable of zero-padding amounts,
            or one of the strings:
            - "same": pads so the output keeps the input's spatial size; requires
              stride 1 and dilation 1.
            - "valid": no padding.
        padding_mode: Mode forwarded to ``torch.nn.functional.pad``. That function
            accepts "constant", "reflect", "replicate" and "circular"; some modes
            are not available for every dimensionality.
        stride: Stride for computing output values (int or per-dimension iterable).
        dilation: Dilation rate for the kernel (int or per-dimension iterable).
        groups: Number of groups for the convolution.

    Returns:
        The convolved tensor. The FFT is computed on ``.float()`` copies of the
        inputs, so the result is float32.

    Raises:
        ValueError: If ``padding`` is an unsupported string, or if "same" is
            combined with a stride or dilation other than 1. Also raised when an
            iterable argument has the wrong length.
    """
    # Cast padding, stride & dilation to per-dimension tuples.
    n = signal.ndim - 2
    stride_ = to_ntuple(stride, n=n)
    dilation_ = to_ntuple(dilation, n=n)
    if isinstance(padding, str):
        if padding == "same":
            # Check the normalized tuples so that e.g. stride=(1, 1) is accepted too.
            if any(s != 1 for s in stride_) or any(d != 1 for d in dilation_):
                raise ValueError("stride and dilation must be 1 for padding='same'.")
            # Possibly fractional; split via floor/ceil below to support even kernels.
            padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
        elif padding == "valid":
            padding_ = (0,) * n
        else:
            # Previously this fell through and crashed later with UnboundLocalError.
            raise ValueError(f"Padding mode {padding} not supported.")
    else:
        padding_ = to_ntuple(padding, n=n)

    # internal dilation offsets
    offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
    offset[(slice(None), slice(None), *((0,) * n))] = 1.0

    # correct the kernel by cutting off unwanted dilation trailing zeros
    cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)

    # pad the kernel internally according to the dilation parameters
    kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

    # Pad the input signal (round to support even sized convolutions); F.pad expects
    # the last dimension first, hence the reversal.
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # padded signal size before the extra even-length pad
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    # Zero-pad the kernel at the end of every spatial dim to the signal's size.
    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim)))

    # Conjugate the kernel spectrum: cross-correlation, matching torch convNd.
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values (valid region only) and apply the stride.
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    )
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
@@ -38,42 +38,58 @@ class BrightnessAdditiveTransform(ImageOnlyTransform):
38
38
  """
39
39
  Adds random additive brightness noise sampled from a Gaussian distribution (mu, sigma).
40
40
 
41
- Supports per-channel brightness sampling or shared brightness across all channels.
41
+ Supports either synchronized brightness shift across all channels or per-channel brightness shift.
42
42
 
43
43
  Args:
44
44
  mu (float): Mean of the Gaussian used to sample brightness shifts.
45
45
  sigma (float): Standard deviation of the Gaussian.
46
- per_channel (bool): If True, brightness shifts are sampled separately per channel.
46
+ synchronize_channels (bool): If True, brightness shifts are shared across all channels.
47
47
  p_per_channel (float): Probability to apply the brightness shift to each channel.
48
48
  """
49
49
 
50
50
  def __init__(self,
51
51
  mu: float,
52
52
  sigma: float,
53
- per_channel: bool = True,
53
+ synchronize_channels: bool = True, # Changed to synchronize_channels
54
54
  p_per_channel: float = 1.0):
55
55
  super().__init__()
56
56
  self.mu = mu
57
57
  self.sigma = sigma
58
- self.per_channel = per_channel
58
+ self.synchronize_channels = synchronize_channels # Now it's being used
59
59
  self.p_per_channel = p_per_channel
60
60
 
61
def get_parameters(self, **data_dict) -> dict:
    """Sample which channels to shift and the Gaussian shift values.

    Both modes draw from N(mu, sigma):
    - ``synchronize_channels=True``: one draw is shared by every channel.
    - ``synchronize_channels=False``: one independent draw is made per channel.

    Returns:
        dict with
        - "apply_to_channel": 1-D LongTensor of selected channel indices.
        - "shift": float tensor with one shift per channel, or None when no
          channel was selected.
    """
    img = data_dict["image"]
    c = img.shape[0]
    apply_to_channel = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()

    if len(apply_to_channel) == 0:
        return {"apply_to_channel": apply_to_channel, "shift": None}

    mean, std = float(self.mu), float(self.sigma)
    if self.synchronize_channels:
        # Do NOT use sample_scalar((mu, sigma)) here. Elsewhere in this package a
        # 2-tuple passed to sample_scalar is a sampling *range* (cf. contrast_range,
        # gamma), not (mean, std). That made this branch disagree with the
        # per-channel Gaussian branch below.
        shift = torch.empty(1, device=img.device).normal_(mean, std).expand(c).clone()
    else:
        shift = torch.empty(c, device=img.device).normal_(mean, std)

    return {"apply_to_channel": apply_to_channel, "shift": shift}
72
77
 
73
78
  def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
74
- for c, b in enumerate(params['brightness']):
75
- if b is not None:
76
- img[c].add_(float(b))
79
+ if params["shift"] is None:
80
+ return img
81
+
82
+ apply_idx = params["apply_to_channel"]
83
+ if apply_idx.numel() == 0:
84
+ return img
85
+
86
+ shift = params["shift"]
87
+ # Build full per-channel shift vector; non-selected channels get shift 0
88
+ shift_full = torch.zeros((img.shape[0],), device=img.device, dtype=img.dtype)
89
+ shift_full[apply_idx] = shift[apply_idx]
90
+
91
+ view_shape = (img.shape[0],) + (1,) * (img.ndim - 1)
92
+ img.add_(shift_full.view(view_shape))
77
93
  return img
78
94
 
79
95
 
@@ -84,7 +100,8 @@ if __name__ == '__main__':
84
100
  os.environ['OMP_NUM_THREADS'] = '1'
85
101
  torch.set_num_threads(1)
86
102
 
87
- mbt = MultiplicativeBrightnessTransform((0.5, 2.), False, 1)
103
+ # mbt = BrightnessAdditiveTransform(0, 0.5,True, 1)
104
+ mbt = MultiplicativeBrightnessTransform((0.5, 2),False, 1)
88
105
 
89
106
  times_torch = []
90
107
  for _ in range(1000):
@@ -26,44 +26,70 @@ class BGContrast():
26
26
  return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
27
27
 
28
28
 
29
+ import torch
30
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
31
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
32
+
33
+
29
34
  class ContrastTransform(ImageOnlyTransform):
30
def __init__(self,
             contrast_range: RandomScalar,
             preserve_range: bool,
             synchronize_channels: bool,
             p_per_channel: float = 1.0):
    """Configure the contrast augmentation.

    Args:
        contrast_range: Multiplier applied around each channel's mean. It is
            sampled with ``sample_scalar``.
        preserve_range: If True, each augmented channel is clamped back to its
            original [min, max].
        synchronize_channels: If True, one multiplier is shared by all selected
            channels.
        p_per_channel: Probability that an individual channel is augmented;
            stored as float.
    """
    super().__init__()
    self.contrast_range = contrast_range
    self.preserve_range = preserve_range
    self.synchronize_channels = synchronize_channels
    self.p_per_channel = float(p_per_channel)
36
41
 
37
42
  def get_parameters(self, **data_dict) -> dict:
38
- shape = data_dict['image'].shape
39
- apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
40
- if self.synchronize_channels:
41
- multipliers = torch.Tensor([sample_scalar(self.contrast_range, image=data_dict['image'], channel=None)] * len(apply_to_channel))
43
+ img = data_dict["image"]
44
+ c = img.shape[0]
45
+
46
+ # sample on correct device
47
+ apply_idx = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()
48
+ n = apply_idx.numel()
49
+
50
+ if n == 0:
51
+ multipliers = None
52
+ elif self.synchronize_channels:
53
+ m = float(sample_scalar(self.contrast_range, image=img, channel=None))
54
+ multipliers = torch.full((n,), m, device=img.device, dtype=img.dtype)
42
55
  else:
43
- multipliers = torch.Tensor([sample_scalar(self.contrast_range, image=data_dict['image'], channel=c) for c in apply_to_channel])
44
- return {
45
- 'apply_to_channel': apply_to_channel,
46
- 'multipliers': multipliers
47
- }
56
+ # Still a Python loop because sample_scalar is scalar-by-scalar
57
+ # Use .tolist() to avoid iterating tensor scalars in Python
58
+ ms = [sample_scalar(self.contrast_range, image=img, channel=int(ch)) for ch in apply_idx.tolist()]
59
+ multipliers = torch.as_tensor(ms, device=img.device, dtype=img.dtype)
60
+
61
+ return {"apply_to_channel": apply_idx, "multipliers": multipliers}
48
62
 
49
63
  def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
50
- if len(params['apply_to_channel']) == 0:
64
+ idx = params["apply_to_channel"]
65
+ multipliers = params["multipliers"]
66
+ if multipliers is None or idx.numel() == 0:
51
67
  return img
52
- # array notation is not faster, let's leave it like this
53
- for i in range(len(params['apply_to_channel'])):
54
- c = params['apply_to_channel'][i]
55
- mean = img[c].mean()
56
- if self.preserve_range:
57
- minm = img[c].min()
58
- maxm = img[c].max()
59
-
60
- # this is faster than having it in one line because this circumvents reallocating memory
61
- img[c] -= mean
62
- img[c] *= params['multipliers'][i]
63
- img[c] += mean
64
-
65
- if self.preserve_range:
66
- img[c].clamp_(minm, maxm)
68
+
69
+ if self.preserve_range:
70
+ for i in range(idx.numel()):
71
+ c = int(idx[i])
72
+ m = multipliers[i]
73
+
74
+ x = img[c]
75
+ mean = x.mean()
76
+ minm = x.min()
77
+ maxm = x.max()
78
+
79
+ x.sub_(mean)
80
+ x.mul_(m)
81
+ x.add_(mean)
82
+ x.clamp_(minm, maxm)
83
+ else:
84
+ for i in range(idx.numel()):
85
+ c = int(idx[i])
86
+ m = multipliers[i]
87
+
88
+ x = img[c]
89
+ mean = x.mean()
90
+ x.sub_(mean)
91
+ x.mul_(m)
92
+ x.add_(mean)
67
93
 
68
94
  return img
69
95
 
@@ -0,0 +1,135 @@
1
+ from typing import Optional
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class GammaTransform(ImageOnlyTransform):
    """Per-channel gamma augmentation with optional inversion and statistics retention.

    Every selected channel ``x`` is remapped in place to
    ``((x - min) / max(range, 1e-7)) ** gamma * range + min``. Optionally:
    - the channel is negated before and after the remapping (``p_invert_image``);
    - its original mean and std are restored afterwards (``p_retain_stats``).

    Args:
        gamma: Exponent, sampled with ``sample_scalar``.
        p_invert_image: Per-selected-channel probability of gamma-transforming
            the negated channel.
        synchronize_channels: If True, a single gamma is shared by all selected
            channels.
        p_per_channel: Probability that a channel is augmented at all.
        p_retain_stats: Per-selected-channel probability of restoring the
            original mean and std.
    """

    def __init__(self,
                 gamma: RandomScalar,
                 p_invert_image: float,
                 synchronize_channels: bool,
                 p_per_channel: float,
                 p_retain_stats: float):
        super().__init__()
        self.gamma = gamma
        self.p_invert_image = float(p_invert_image)
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = float(p_per_channel)
        self.p_retain_stats = float(p_retain_stats)

    def get_parameters(self, **data_dict) -> dict:
        image: torch.Tensor = data_dict["image"]
        dev, dt = image.device, image.dtype

        selected = (torch.rand(image.shape[0], device=dev) < self.p_per_channel).nonzero(as_tuple=False).flatten()
        count = selected.numel()
        if not count:
            return {"apply_to_channel": selected,
                    "retain_stats": None,
                    "invert_image": None,
                    "gamma": None}

        # RNG draw order is deliberate: retain flags, then invert flags, then gammas.
        retain = torch.rand(count, device=dev) < self.p_retain_stats
        invert = torch.rand(count, device=dev) < self.p_invert_image

        if self.synchronize_channels:
            shared = float(sample_scalar(self.gamma, image=image, channel=None))
            gammas = torch.full((count,), shared, device=dev, dtype=dt)
        else:
            per_channel = [float(sample_scalar(self.gamma, image=image, channel=int(ch)))
                           for ch in selected.tolist()]
            gammas = torch.as_tensor(per_channel, device=dev, dtype=dt)

        return {
            "apply_to_channel": selected,
            "retain_stats": retain,
            "invert_image": invert,
            "gamma": gammas,
        }

    @staticmethod
    def _gamma_channel(x: torch.Tensor, g: torch.Tensor, keep_stats: bool, flip: bool,
                       eps: float = 1e-7) -> None:
        """Apply the gamma curve to the single-channel view ``x`` in place."""
        if flip:
            x.mul_(-1)
        if keep_stats:
            mean = x.mean()
            std = x.std()

        lo = x.min()
        span = x.max() - lo
        # x = (((x - min) / max(span, eps)) ** g) * span + min, without temporaries
        x.sub_(lo).div_(torch.clamp(span, min=eps)).pow_(g).mul_(span).add_(lo)

        if keep_stats:
            cur_mean = x.mean()
            cur_std = x.std()
            x.sub_(cur_mean).mul_(std / torch.clamp(cur_std, min=eps)).add_(mean)
        if flip:
            x.mul_(-1)

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        selected: torch.Tensor = params["apply_to_channel"]
        if selected.numel() == 0:
            return img

        gammas: torch.Tensor = params["gamma"]
        flags = zip(selected.tolist(), params["retain_stats"].tolist(), params["invert_image"].tolist())
        # Plain loop over selected channels (cheap for the usual small channel counts).
        for k, (ch, keep_stats, flip) in enumerate(flags):
            self._gamma_channel(img[ch], gammas[k], keep_stats, flip)
        return img
105
+
106
+
107
+
108
if __name__ == '__main__':
    from time import time
    import numpy as np
    import os

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _mean_runtime(transform, make_sample, repeats=100):
        """Average wall-clock time of ``transform(**make_sample())``; input creation is not timed."""
        elapsed = []
        for _ in range(repeats):
            sample = make_sample()
            start = time()
            transform(**sample)
            elapsed.append(time() - start)
        return np.mean(elapsed)

    # batchgeneratorsv2 implementation
    v2_gamma = GammaTransform((0.7, 1.5), 0, False, 1, 1)
    print('torch', _mean_runtime(v2_gamma, lambda: {'image': torch.ones((2, 128, 192, 64))}))

    # reference: original batchgenerators implementation
    from batchgenerators.transforms.color_transforms import GammaTransform as BGGamma

    bg_gamma = BGGamma((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1)
    print('bg', _mean_runtime(bg_gamma, lambda: {'data': np.ones((1, 2, 128, 192, 64))}))
@@ -0,0 +1,104 @@
1
+ from typing import Tuple
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class GaussianNoiseTransform(ImageOnlyTransform):
    """Add zero-mean Gaussian noise to randomly selected channels (in place).

    Note: despite its name, the value sampled from ``noise_variance`` is used as
    the noise *standard deviation*. It is passed as ``std`` to ``normal_`` or
    used as a scale on unit-variance noise.

    Args:
        noise_variance: Noise scale, sampled with ``sample_scalar``.
        p_per_channel: Probability that a channel receives noise.
        synchronize_channels: If True, one scale is shared by all selected
            channels. The noise itself is still drawn independently per channel.
    """

    def __init__(self,
                 noise_variance: RandomScalar = (0, 0.1),
                 p_per_channel: float = 1.,
                 synchronize_channels: bool = False):
        super().__init__()
        self.noise_variance = noise_variance
        self.p_per_channel = p_per_channel
        self.synchronize_channels = synchronize_channels

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict["image"]

        # Channel selection mask, on the image's device.
        mask = torch.rand(image.shape[0], device=image.device) < self.p_per_channel
        # Indices and count are precomputed so _apply_to_image need not recompute them.
        indices = mask.nonzero(as_tuple=False).flatten()
        count = indices.numel()

        if count == 0:
            scales = None
        elif self.synchronize_channels:
            scales = sample_scalar(self.noise_variance, image)
        else:
            scales = [sample_scalar(self.noise_variance, image) for _ in range(count)]

        return {"apply_mask": mask, "apply_idx": indices, "num_apply": count, "sigmas": scales}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        count = params["num_apply"]
        scales = params["sigmas"]
        if count == 0 or scales is None:
            return img

        channel_shape = img.shape[1:]
        # Noise is generated only for the selected channels.
        noise = torch.empty((count, *channel_shape), device=img.device, dtype=img.dtype)
        if self.synchronize_channels:
            noise.normal_(mean=0.0, std=float(scales))
        else:
            noise.normal_()
            scale_t = torch.as_tensor(scales, device=img.device, dtype=img.dtype)
            # (count,) -> (count, 1, 1, ...) so each channel is scaled by its own sigma.
            noise.mul_(scale_t.view(count, *([1] * len(channel_shape))))

        # img[idx] with a tensor index is a copy; indexed "+=" writes the result back into img.
        img[params["apply_idx"]] += noise
        return img
73
+
74
+
75
if __name__ == "__main__":
    # `os` was previously used below without being imported (NameError on start).
    import os
    from time import time

    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    # batchgeneratorsv2 implementation
    gnt = GaussianNoiseTransform((0, 0.1), 1, False)

    times = []
    for _ in range(1000):
        data_dict = {'image': torch.ones((2, 32, 32, 32))}
        st = time()
        out = gnt(**data_dict)
        times.append(time() - st)
    print('torch', np.mean(times))

    # Reference: original batchgenerators implementation. It is imported under an alias
    # so that it does not shadow this module's GaussianNoiseTransform.
    from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform as BGGaussianNoiseTransform

    gnt_bg = BGGaussianNoiseTransform((0, 0.1), 1, 1, True)

    times = []
    for _ in range(1000):
        data_dict = {'data': np.ones((1, 2, 32, 32, 32))}
        st = time()
        out = gnt_bg(**data_dict)
        times.append(time() - st)

    print('bg', np.mean(times))
    # torch is 2.5x faster
@@ -67,12 +67,11 @@ class CutOffOutliersTransform(ImageOnlyTransform):
67
67
  if retain_std[c]:
68
68
  orig_std = img_c.std()
69
69
 
70
- # Use numpy only to calculate percentiles
71
- img_c_np = img_c.detach().cpu().numpy()
72
- lower_val = np.percentile(img_c_np, perc[0])
73
- upper_val = np.percentile(img_c_np, perc[1])
70
+ # Percentiles in torch to avoid numpy roundtrip
71
+ q = torch.tensor([perc[0] / 100.0, perc[1] / 100.0], device=img_c.device, dtype=torch.float32)
72
+ lower_val, upper_val = torch.quantile(img_c.float(), q)
74
73
 
75
- img_c_clipped = img_c.clamp(min=float(lower_val), max=float(upper_val))
74
+ img_c_clipped = img_c.clamp(min=lower_val.item(), max=upper_val.item())
76
75
 
77
76
  if retain_std[c]:
78
77
  clipped_std = img_c_clipped.std()
@@ -99,4 +98,4 @@ if __name__ == '__main__':
99
98
  params = transform.get_parameters(image=image)
100
99
  image_clipped = transform._apply_to_image(image.clone(), **params)
101
100
 
102
- view_batch(image, image_clipped, image_clipped-image)
101
+ view_batch(image, image_clipped, image_clipped-image)
@@ -3,14 +3,11 @@ from typing import Union, List, Tuple, Callable
3
3
 
4
4
  import numpy as np
5
5
  import torch
6
- from fft_conv_pytorch import fft_conv
6
+ from skimage.morphology import ball, disk
7
7
 
8
+ from batchgeneratorsv2.helpers.fft_conv import fft_conv
8
9
  from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
9
10
  from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
10
- from skimage.morphology import ball, disk
11
- from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
12
-
13
- import torch.nn.functional as F
14
11
 
15
12
 
16
13
  def binary_dilation_torch(input_tensor, structure_element):
@@ -71,7 +71,7 @@ class BlankRectangleTransform(ImageOnlyTransform):
71
71
  }
72
72
 
73
73
  def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
74
- out = img.clone()
74
+ out = img
75
75
  for c, (apply, rects) in enumerate(zip(params['apply_channel'], params['rectangles'])):
76
76
  if not apply:
77
77
  continue
@@ -147,4 +147,4 @@ if __name__ == '__main__':
147
147
  image_aug = transform._apply_to_image(image, **params)
148
148
 
149
149
  from batchviewer import view_batch
150
- view_batch(image, image_aug)
150
+ view_batch(image, image_aug)
@@ -6,9 +6,9 @@ import torch
6
6
  from skimage.data import camera
7
7
  from torch.nn.functional import pad, conv3d, conv1d, conv2d
8
8
 
9
+ from batchgeneratorsv2.helpers.fft_conv import fft_conv
9
10
  from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
10
11
  from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
11
- from fft_conv_pytorch import fft_conv
12
12
 
13
13
 
14
14
  def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6):
@@ -219,7 +219,7 @@ if __name__ == "__main__":
219
219
  shape = (128, 164, 64)
220
220
  num_warmup_for_benchmark = 1
221
221
  num_repeats = 10
222
- for sigma_range in (0.1, 1, 10):
222
+ for sigma_range in (0.1, 1, 10, 20):
223
223
  print(shape, sigma_range)
224
224
  gnt2 = GaussianBlurTransform(sigma_range, False, False, 1, benchmark=False)
225
225
  times = []
@@ -257,4 +257,4 @@ if __name__ == "__main__":
257
257
  print('batchgenerator', np.median(times))
258
258
  print()
259
259
  #
260
- # print(gnt.benchmark_use_fft)
260
+ # print(gnt.benchmark_use_fft)