batchgeneratorsv2 0.2.3__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/PKG-INFO +3 -3
  2. batchgeneratorsv2-0.3.2/batchgeneratorsv2/helpers/fft_conv.py +149 -0
  3. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/base/basic_transform.py +1 -1
  4. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  5. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/intensity/contrast.py +54 -27
  6. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  7. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  8. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  9. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  10. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  11. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  12. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  13. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  14. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +2 -5
  15. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  16. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +3 -3
  17. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  18. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/noise/rician.py +61 -0
  19. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  20. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  21. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/low_resolution.py +1 -1
  22. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  23. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/spatial.py +292 -264
  24. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/utils/compose.py +89 -0
  25. batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  26. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +4 -2
  27. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/pseudo2d.py +3 -5
  28. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/random.py +23 -0
  29. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/remove_label.py +4 -1
  30. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/seg_to_regions.py +1 -1
  31. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/PKG-INFO +3 -3
  32. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/SOURCES.txt +15 -1
  33. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/requires.txt +0 -1
  34. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/pyproject.toml +2 -3
  35. batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/intensity/brightness.py +0 -63
  36. batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/intensity/gamma.py +0 -88
  37. batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -80
  38. batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/utils/compose.py +0 -14
  39. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/LICENSE +0 -0
  40. {batchgeneratorsv2-0.2.3/batchgeneratorsv2 → batchgeneratorsv2-0.3.2/batchgeneratorsv2/benchmarks}/__init__.py +0 -0
  41. {batchgeneratorsv2-0.2.3/batchgeneratorsv2/benchmarks → batchgeneratorsv2-0.3.2/batchgeneratorsv2/benchmarks/bg_comparison}/__init__.py +0 -0
  42. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +0 -0
  43. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +0 -0
  44. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
  45. {batchgeneratorsv2-0.2.3/batchgeneratorsv2/benchmarks/bg_comparison → batchgeneratorsv2-0.3.2/batchgeneratorsv2/dataloading}/__init__.py +0 -0
  46. {batchgeneratorsv2-0.2.3/batchgeneratorsv2/dataloading → batchgeneratorsv2-0.3.2/batchgeneratorsv2/helpers}/__init__.py +0 -0
  47. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
  48. {batchgeneratorsv2-0.2.3/batchgeneratorsv2/helpers → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms}/__init__.py +0 -0
  49. {batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/base}/__init__.py +0 -0
  50. {batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/base → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity}/__init__.py +0 -0
  51. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/intensity/inversion.py +0 -0
  52. {batchgeneratorsv2-0.2.3/batchgeneratorsv2/transforms/intensity → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/local}/__init__.py +0 -0
  53. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  54. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
  55. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
  56. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  57. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  58. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
  59. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/transpose.py +0 -0
  60. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  61. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
  62. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
  63. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
  64. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
  65. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/readme.md +0 -0
  66. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/setup.cfg +0 -0
  67. {batchgeneratorsv2-0.2.3 → batchgeneratorsv2-0.3.2}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: batchgeneratorsv2
3
- Version: 0.2.3
3
+ Version: 0.3.2
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -220,8 +220,8 @@ Description-Content-Type: text/markdown
220
220
  License-File: LICENSE
221
221
  Requires-Dist: torch>=2.0.0
222
222
  Requires-Dist: numpy
223
- Requires-Dist: fft-conv-pytorch
224
223
  Requires-Dist: batchgenerators>=0.25
224
+ Dynamic: license-file
225
225
 
226
226
  # batchgeneratorsv2
227
227
  This repository is a work in progress. It builds upon the [batchgenerators](https://github.com/MIC-DKFZ/batchgenerators)
@@ -0,0 +1,149 @@
1
+ from typing import Iterable, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as f
5
+ from torch import Tensor, nn
6
+ from torch.fft import irfftn, rfftn
7
+ from math import ceil, floor
8
+
9
+
10
+ # Taken from here: https://github.com/vcasellesb/fft-conv-pytorch/tree/non-tup-slice-fix. THANK YOU!
11
+ # Original codebase is here: https://github.com/fkodom/fft-conv-pytorch -> unfortunately was not updated for pytorch 2.9.
12
+
13
+
14
def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Channel-mixing product of two complex tensors (the core step of an FFT convolution).

    ``a`` is laid out as (batch, in_channels, *spatial) and ``b`` as
    (out_channels, in_channels // groups, *spatial). At every spatial position the input
    channels of ``a`` are contracted against ``b``, independently within each of the
    ``groups`` channel groups. The result is a complex64 tensor of shape
    (batch, out_channels, *spatial).
    """
    # Split channels into groups: lhs -> (B, G, Cin/G, *S), rhs -> (G, Cout/G, Cin/G, *S)
    lhs = a.view(a.size(0), groups, -1, *a.shape[2:])
    rhs = b.view(groups, -1, *b.shape[1:])

    # Move the channel axes last so that `@` contracts over input channels:
    # lhs -> (B, G, *S, 1, Cin/G), rhs -> (G, *S, Cin/G, Cout/G)
    lhs = torch.movedim(lhs, 2, lhs.dim() - 1).unsqueeze(-2)
    rhs = torch.movedim(rhs, (1, 2), (rhs.dim() - 1, rhs.dim() - 2))

    # (x + iy)(u + iv) = (xu - yv) + i(yu + xv), evaluated with real-valued matmuls
    re_part = lhs.real @ rhs.real - lhs.imag @ rhs.imag
    im_part = lhs.imag @ rhs.real + lhs.real @ rhs.imag

    # (B, G, *S, 1, Cout/G) -> (B, G, Cout/G, *S)
    re_part = torch.movedim(re_part, re_part.dim() - 1, 2).squeeze(-1)
    im_part = torch.movedim(im_part, im_part.dim() - 1, 2).squeeze(-1)

    # A freshly allocated (contiguous) complex buffer, so the final .view is always valid
    result = torch.zeros(re_part.shape, dtype=torch.complex64, device=lhs.device)
    result.real, result.imag = re_part, im_part

    # Merge (G, Cout/G) back into a single channel axis
    return result.view(result.size(0), -1, *result.shape[3:])
35
+
36
+
37
def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
    """Broadcast ``val`` to a tuple of length ``n``.

    A non-iterable value is repeated ``n`` times. An iterable is converted to a tuple and
    must already contain exactly ``n`` entries. This is convenient for convolution arguments
    (padding, stride, dilation) that may be given as a single int or one value per dimension.

    Args:
        val: (Union[int, Iterable[int]]) Scalar or per-dimension values.
        n: (int) Desired tuple length.

    Returns:
        (Tuple[int, ...]) Tuple of length ``n``.

    Raises:
        ValueError: if ``val`` is an iterable whose length differs from ``n``.
    """
    if not isinstance(val, Iterable):
        return (val,) * n

    as_tuple = tuple(val)
    if len(as_tuple) != n:
        raise ValueError(f"Cannot cast tuple of length {len(as_tuple)} to length {n}.")
    return as_tuple
56
+
57
+
58
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Union[Tensor, None] = None,
    padding: Union[int, Iterable[int], str] = 0,
    padding_mode: str = "constant",
    stride: Union[int, Iterable[int]] = 1,
    dilation: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast Fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order to mimic the PyTorch direct convolution).

    Args:
        signal: (Tensor) Input of shape (batch, in_channels, *spatial) to be convolved.
        kernel: (Tensor) Convolution kernel of shape
            (out_channels, in_channels // groups, *kernel_spatial).
        bias: (Tensor) Optional bias, one value per output channel, added to the output.
        padding: (Union[int, Iterable[int], str]) If int / iterable, the number of samples
            used to pad the input on both sides of each spatial dimension. If str, only
            "same" is supported: the input is padded so the output keeps the input's
            spatial size (requires stride 1; dilation is taken into account).
        padding_mode: (str) Mode forwarded to ``torch.nn.functional.pad`` (e.g. "constant",
            "reflect", "replicate", "circular"). Whether a mode is available for a given
            dimensionality is decided by ``pad``.
        stride: (Union[int, Iterable[int]]) Stride size for computing output values.
        dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
        groups: (int) Number of groups for the convolution.

    Returns:
        (Tensor) Convolved tensor. The FFTs are computed on ``.float()`` copies, so the
        result is float32 irrespective of the input dtype.

    Raises:
        ValueError: if ``padding`` is a string other than "same", if "same" padding is
            combined with a stride other than 1, or if a per-dimension argument has the
            wrong length.
    """

    # Cast padding, stride & dilation to per-spatial-dimension tuples.
    n = signal.ndim - 2
    stride_ = to_ntuple(stride, n=n)
    dilation_ = to_ntuple(dilation, n=n)
    if isinstance(padding, str):
        if padding != "same":
            raise ValueError(f"Padding mode {padding} not supported.")
        # Check the normalized tuple so that e.g. stride=(1, 1, 1) is accepted too.
        if any(s != 1 for s in stride_):
            raise ValueError("stride must be 1 for padding='same'.")
        # The dilated kernel spans d * (k - 1) + 1 samples, so d * (k - 1) samples of total
        # padding keep the output size equal to the input size. Odd totals are split
        # floor/ceil below.
        padding_ = [d * (k - 1) / 2 for k, d in zip(kernel.shape[2:], dilation_)]
    else:
        padding_ = to_ntuple(padding, n=n)

    # internal dilation offsets: a one-hot block of size `dilation`; the Kronecker product
    # with it spreads the kernel taps `d` samples apart. Built like the kernel it is applied to.
    offset = torch.zeros(1, 1, *dilation_, device=kernel.device, dtype=kernel.dtype)
    offset[(slice(None), slice(None), *((0,) * n))] = 1.0

    # correct the kernel by cutting off unwanted dilation trailing zeros
    cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)

    # pad the kernel internally according to the dilation parameters
    kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

    # Pad the input signal (floor/ceil rounding supports even sized convolutions).
    # f.pad expects (left, right) pairs starting with the last dimension, hence [::-1].
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # original signal size without padding to even
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    # Zero-pad the kernel (on the high side) to the signal's spatial size.
    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
    signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim)))

    # Conjugating the kernel spectrum turns the product into a cross-correlation,
    # which is what PyTorch's direct convolution computes.
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values (keep only "valid" positions, then apply the stride)
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    )
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
@@ -22,7 +22,7 @@ class BasicTransform(abc.ABC):
22
22
  data_dict['image'] = self._apply_to_image(data_dict['image'], **params)
23
23
 
24
24
  if data_dict.get('regression_target') is not None:
25
- data_dict['regression_target'] = self._apply_to_segmentation(data_dict['regression_target'], **params)
25
+ data_dict['regression_target'] = self._apply_to_regr_target(data_dict['regression_target'], **params)
26
26
 
27
27
  if data_dict.get('segmentation') is not None:
28
28
  data_dict['segmentation'] = self._apply_to_segmentation(data_dict['segmentation'], **params)
@@ -0,0 +1,123 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class MultiplicativeBrightnessTransform(ImageOnlyTransform):
    """Multiply randomly selected image channels by a randomly sampled factor.

    Args:
        multiplier_range: ``RandomScalar`` from which the multipliers are drawn with
            ``sample_scalar``.
        synchronize_channels: if True, a single multiplier is sampled and shared by all
            selected channels; otherwise each selected channel gets its own multiplier.
        p_per_channel: probability that a given channel is modified.
    """

    def __init__(self, multiplier_range: RandomScalar, synchronize_channels: bool, p_per_channel: float = 1):
        super().__init__()
        self.multiplier_range = multiplier_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        img = data_dict['image']
        apply_to_channel = torch.where(torch.rand(img.shape[0]) < self.p_per_channel)[0]
        if self.synchronize_channels:
            multipliers = torch.Tensor(
                [sample_scalar(self.multiplier_range, image=img, channel=None)] * len(apply_to_channel))
        else:
            # Pass plain Python ints as the channel index (as ContrastTransform and
            # GammaTransform do) instead of 0-d tensors, so callable ranges get a regular int.
            multipliers = torch.Tensor(
                [sample_scalar(self.multiplier_range, image=img, channel=c) for c in apply_to_channel.tolist()])
        return {
            'apply_to_channel': apply_to_channel,
            'multipliers': multipliers
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Multiply the selected channels of ``img`` in place and return it."""
        if len(params['apply_to_channel']) == 0:
            return img
        # A per-channel loop is used on purpose: the array form
        # img[idx] *= multipliers.view(-1, *[1] * (img.ndim - 1)) was found to be a lot slower.
        for c, m in zip(params['apply_to_channel'], params['multipliers']):
            img[c] *= m
        return img
35
+
36
+
37
class BrightnessAdditiveTransform(ImageOnlyTransform):
    """
    Adds random additive brightness shifts sampled from a Gaussian distribution N(mu, sigma).

    Supports either one brightness shift shared by all selected channels or an independent
    shift per channel.

    Args:
        mu (float): Mean of the Gaussian used to sample brightness shifts.
        sigma (float): Standard deviation of the Gaussian.
        synchronize_channels (bool): If True, a single shift is sampled and applied to every
            selected channel; otherwise each channel gets its own shift.
        p_per_channel (float): Probability to apply the brightness shift to each channel.
    """

    def __init__(self,
                 mu: float,
                 sigma: float,
                 synchronize_channels: bool = True,
                 p_per_channel: float = 1.0):
        super().__init__()
        self.mu = mu
        self.sigma = sigma
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        img = data_dict["image"]
        c = img.shape[0]
        apply_to_channel = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()

        if len(apply_to_channel) == 0:
            return {"apply_to_channel": apply_to_channel, "shift": None}

        mu, sigma = float(self.mu), float(self.sigma)
        if self.synchronize_channels:
            # Draw ONE Gaussian sample and share it across channels. sample_scalar must not be
            # used here: it treats a 2-tuple as a value range (like multiplier_range /
            # contrast_range elsewhere), not as (mean, std).
            shift = torch.empty(1, device=img.device).normal_(mu, sigma).repeat(c)
        else:
            shift = torch.empty(c, device=img.device).normal_(mu, sigma)

        return {"apply_to_channel": apply_to_channel, "shift": shift}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Add the sampled shifts to the selected channels of ``img`` in place and return it."""
        if params["shift"] is None:
            return img

        apply_idx = params["apply_to_channel"]
        if apply_idx.numel() == 0:
            return img

        shift = params["shift"]
        # Build full per-channel shift vector; non-selected channels get shift 0
        shift_full = torch.zeros((img.shape[0],), device=img.device, dtype=img.dtype)
        shift_full[apply_idx] = shift[apply_idx]

        view_shape = (img.shape[0],) + (1,) * (img.ndim - 1)
        img.add_(shift_full.view(view_shape))
        return img
94
+
95
+
96
if __name__ == '__main__':
    # Micro-benchmark: this implementation vs. the original batchgenerators transform.
    import os
    from time import time

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _average_runtime(transform, make_input, repetitions=1000):
        """Mean wall-clock seconds of ``transform(**make_input())``; input creation is untimed."""
        elapsed = []
        for _ in range(repetitions):
            sample = make_input()
            started = time()
            transform(**sample)
            elapsed.append(time() - started)
        return np.mean(elapsed)

    # Alternative to benchmark: BrightnessAdditiveTransform(0, 0.5, True, 1)
    mbt = MultiplicativeBrightnessTransform((0.5, 2), False, 1)
    print('torch', _average_runtime(mbt, lambda: {'image': torch.ones((2, 128, 192, 64))}))

    from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform

    gnt_bg = BrightnessMultiplicativeTransform((0.5, 2), True, p_per_sample=1)
    print('bg', _average_runtime(gnt_bg, lambda: {'data': np.ones((1, 2, 128, 192, 64))}))
@@ -25,44 +25,71 @@ class BGContrast():
25
25
  def __repr__(self):
26
26
  return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
27
27
 
28
+
29
+ import torch
30
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
31
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
32
+
33
+
28
34
  class ContrastTransform(ImageOnlyTransform):
29
- def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
35
+ def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1.0):
30
36
  super().__init__()
31
37
  self.contrast_range = contrast_range
32
38
  self.preserve_range = preserve_range
33
39
  self.synchronize_channels = synchronize_channels
34
- self.p_per_channel = p_per_channel
40
+ self.p_per_channel = float(p_per_channel)
35
41
 
36
42
  def get_parameters(self, **data_dict) -> dict:
37
- shape = data_dict['image'].shape
38
- apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]
39
- if self.synchronize_channels:
40
- multipliers = torch.Tensor([sample_scalar(self.contrast_range, image=data_dict['image'], channel=None)] * len(apply_to_channel))
43
+ img = data_dict["image"]
44
+ c = img.shape[0]
45
+
46
+ # sample on correct device
47
+ apply_idx = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()
48
+ n = apply_idx.numel()
49
+
50
+ if n == 0:
51
+ multipliers = None
52
+ elif self.synchronize_channels:
53
+ m = float(sample_scalar(self.contrast_range, image=img, channel=None))
54
+ multipliers = torch.full((n,), m, device=img.device, dtype=img.dtype)
41
55
  else:
42
- multipliers = torch.Tensor([sample_scalar(self.contrast_range, image=data_dict['image'], channel=c) for c in apply_to_channel])
43
- return {
44
- 'apply_to_channel': apply_to_channel,
45
- 'multipliers': multipliers
46
- }
56
+ # Still a Python loop because sample_scalar is scalar-by-scalar
57
+ # Use .tolist() to avoid iterating tensor scalars in Python
58
+ ms = [sample_scalar(self.contrast_range, image=img, channel=int(ch)) for ch in apply_idx.tolist()]
59
+ multipliers = torch.as_tensor(ms, device=img.device, dtype=img.dtype)
60
+
61
+ return {"apply_to_channel": apply_idx, "multipliers": multipliers}
47
62
 
48
63
  def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
49
- if len(params['apply_to_channel']) == 0:
64
+ idx = params["apply_to_channel"]
65
+ multipliers = params["multipliers"]
66
+ if multipliers is None or idx.numel() == 0:
50
67
  return img
51
- # array notation is not faster, let's leave it like this
52
- for i in range(len(params['apply_to_channel'])):
53
- c = params['apply_to_channel'][i]
54
- mean = img[c].mean()
55
- if self.preserve_range:
56
- minm = img[c].min()
57
- maxm = img[c].max()
58
-
59
- # this is faster than having it in one line because this circumvents reallocating memory
60
- img[c] -= mean
61
- img[c] *= params['multipliers'][i]
62
- img[c] += mean
63
-
64
- if self.preserve_range:
65
- img[c].clamp_(minm, maxm)
68
+
69
+ if self.preserve_range:
70
+ for i in range(idx.numel()):
71
+ c = int(idx[i])
72
+ m = multipliers[i]
73
+
74
+ x = img[c]
75
+ mean = x.mean()
76
+ minm = x.min()
77
+ maxm = x.max()
78
+
79
+ x.sub_(mean)
80
+ x.mul_(m)
81
+ x.add_(mean)
82
+ x.clamp_(minm, maxm)
83
+ else:
84
+ for i in range(idx.numel()):
85
+ c = int(idx[i])
86
+ m = multipliers[i]
87
+
88
+ x = img[c]
89
+ mean = x.mean()
90
+ x.sub_(mean)
91
+ x.mul_(m)
92
+ x.add_(mean)
66
93
 
67
94
  return img
68
95
 
@@ -0,0 +1,135 @@
1
+ from typing import Optional
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class GammaTransform(ImageOnlyTransform):
    """Apply a random gamma (power-law) intensity transform to randomly selected channels.

    Each selected channel is min-max normalized, raised to the power ``gamma`` and mapped
    back to its original [min, max]. Per channel it may optionally be negated before and
    after the transform (probability ``p_invert_image``). Its mean/std may also be restored
    afterwards (probability ``p_retain_stats``).
    """

    def __init__(self,
                 gamma: RandomScalar,
                 p_invert_image: float,
                 synchronize_channels: bool,
                 p_per_channel: float,
                 p_retain_stats: float):
        super().__init__()
        self.gamma = gamma
        self.p_invert_image = float(p_invert_image)
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = float(p_per_channel)
        self.p_retain_stats = float(p_retain_stats)

    def get_parameters(self, **data_dict) -> dict:
        image: torch.Tensor = data_dict["image"]
        dev, dt = image.device, image.dtype

        chosen = (torch.rand(image.shape[0], device=dev) < self.p_per_channel).nonzero(as_tuple=False).flatten()
        count = chosen.numel()
        if count == 0:
            return {"apply_to_channel": chosen,
                    "retain_stats": None,
                    "invert_image": None,
                    "gamma": None}

        # RNG draw order: retain flags, invert flags, then gamma values
        retain_stats = torch.rand(count, device=dev) < self.p_retain_stats
        invert_image = torch.rand(count, device=dev) < self.p_invert_image

        if self.synchronize_channels:
            shared = float(sample_scalar(self.gamma, image=image, channel=None))
            gamma = torch.full((count,), shared, device=dev, dtype=dt)
        else:
            values = [float(sample_scalar(self.gamma, image=image, channel=int(ch))) for ch in chosen.tolist()]
            gamma = torch.as_tensor(values, device=dev, dtype=dt)

        return {
            "apply_to_channel": chosen,
            "retain_stats": retain_stats,
            "invert_image": invert_image,
            "gamma": gamma,
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Apply the sampled gamma transforms to ``img`` in place and return it."""
        channels: torch.Tensor = params["apply_to_channel"]
        if channels.numel() == 0:
            return img

        eps = 1e-7  # guards against division by a zero range / zero std
        per_channel = zip(channels.tolist(),
                          params["retain_stats"].tolist(),
                          params["invert_image"].tolist(),
                          params["gamma"])

        for channel, keep_stats, invert, g in per_channel:
            x = img[channel]

            if invert:
                x.mul_(-1)

            if keep_stats:
                mean_before = x.mean()
                std_before = x.std()

            lo = x.min()
            hi = x.max()
            span = hi - lo

            # In place: x = (((x - lo) / max(span, eps)) ** g) * span + lo
            x.sub_(lo).div_(torch.clamp(span, min=eps)).pow_(g).mul_(span).add_(lo)

            if keep_stats:
                mean_after = x.mean()
                std_after = x.std()
                x.sub_(mean_after).mul_(std_before / torch.clamp(std_after, min=eps)).add_(mean_before)

            if invert:
                x.mul_(-1)

        return img
105
+
106
+
107
+
108
if __name__ == '__main__':
    # Micro-benchmark: this GammaTransform vs. the original batchgenerators GammaTransform.
    import os
    from time import time

    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _average_runtime(transform, make_input, repetitions=100):
        """Mean wall-clock seconds of ``transform(**make_input())``; input creation is untimed."""
        elapsed = []
        for _ in range(repetitions):
            sample = make_input()
            started = time()
            transform(**sample)
            elapsed.append(time() - started)
        return np.mean(elapsed)

    mbt = GammaTransform((0.7, 1.5), 0, False, 1, 1)
    print('torch', _average_runtime(mbt, lambda: {'image': torch.ones((2, 128, 192, 64))}))

    from batchgenerators.transforms.color_transforms import GammaTransform as BGGamma

    gnt_bg = BGGamma((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1)
    print('bg', _average_runtime(gnt_bg, lambda: {'data': np.ones((1, 2, 128, 192, 64))}))
@@ -0,0 +1,104 @@
1
+ from typing import Tuple
2
+ import torch
3
+
4
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
5
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
6
+
7
+
8
class GaussianNoiseTransform(ImageOnlyTransform):
    """Add zero-mean Gaussian noise to a random subset of channels.

    Args:
        noise_variance: ``RandomScalar`` sampled via ``sample_scalar``. Despite the name, the
            sampled value is used as the noise *standard deviation*.
        p_per_channel: probability that a given channel receives noise.
        synchronize_channels: if True, all selected channels share one sampled sigma;
            otherwise a sigma is sampled for every selected channel.
    """

    def __init__(self,
                 noise_variance: RandomScalar = (0, 0.1),
                 p_per_channel: float = 1.,
                 synchronize_channels: bool = False):
        super().__init__()
        self.noise_variance = noise_variance
        self.p_per_channel = p_per_channel
        self.synchronize_channels = synchronize_channels

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict["image"]

        # Per-channel Bernoulli selection on the image's device; indices/count are cached
        # so _apply_to_image does not have to recompute them.
        selected_mask = torch.rand(image.shape[0], device=image.device) < self.p_per_channel
        selected_idx = selected_mask.nonzero(as_tuple=False).flatten()
        num_selected = selected_idx.numel()

        if num_selected == 0:
            sigmas = None
        elif self.synchronize_channels:
            sigmas = sample_scalar(self.noise_variance, image)
        else:
            sigmas = [sample_scalar(self.noise_variance, image) for _ in range(num_selected)]

        return {"apply_mask": selected_mask, "apply_idx": selected_idx, "num_apply": num_selected, "sigmas": sigmas}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Add noise to the selected channels of ``img`` and return it."""
        num_selected = params["num_apply"]
        if num_selected == 0:
            return img

        selected_idx = params["apply_idx"]
        sigmas = params["sigmas"]
        if sigmas is None:
            return img

        noise_shape = (num_selected, *img.shape[1:])
        if self.synchronize_channels:
            noise = torch.empty(noise_shape, device=img.device, dtype=img.dtype).normal_(mean=0.0, std=float(sigmas))
        else:
            # One sigma per selected channel, broadcast over every non-channel axis.
            sigma_per_channel = torch.as_tensor(sigmas, device=img.device, dtype=img.dtype)
            sigma_per_channel = sigma_per_channel.view((num_selected,) + (1,) * (img.ndim - 1))
            noise = torch.empty(noise_shape, device=img.device, dtype=img.dtype).normal_()
            noise.mul_(sigma_per_channel)

        # img[idx] alone is a copy (advanced indexing); the augmented assignment writes the
        # sum back into img through __setitem__.
        img[selected_idx] += noise
        return img
73
+
74
+
75
if __name__ == "__main__":
    # Micro-benchmark: this GaussianNoiseTransform vs. the original batchgenerators one.
    import os
    from time import time

    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    gnt = GaussianNoiseTransform((0, 0.1), 1, False)

    times = []
    for _ in range(1000):
        data_dict = {'image': torch.ones((2, 32, 32, 32))}
        st = time()
        out = gnt(**data_dict)
        times.append(time() - st)
    print('torch', np.mean(times))

    # Aliased so the module-level GaussianNoiseTransform defined above is not shadowed.
    from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform as BGGaussianNoiseTransform

    gnt_bg = BGGaussianNoiseTransform((0, 0.1), 1, 1, True)

    times = []
    for _ in range(1000):
        data_dict = {'data': np.ones((1, 2, 32, 32, 32))}
        st = time()
        out = gnt_bg(**data_dict)
        times.append(time() - st)

    print('bg', np.mean(times))
    # author's measurement: torch is ~2.5x faster