batchgeneratorsv2 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/PKG-INFO +1 -2
- batchgeneratorsv2-0.3.2/batchgeneratorsv2/helpers/fft_conv.py +149 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/intensity/brightness.py +33 -16
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/intensity/contrast.py +53 -27
- batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/intensity/random_clip.py +5 -6
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +2 -5
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/noise/blank_rectangle.py +2 -2
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/noise/gaussian_blur.py +3 -3
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/noise/sharpen.py +2 -2
- batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/low_resolution.py +1 -1
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/spatial.py +292 -264
- batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/nnunet_masking.py +4 -2
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/pseudo2d.py +3 -5
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/remove_label.py +4 -1
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/seg_to_regions.py +1 -1
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/PKG-INFO +1 -2
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/SOURCES.txt +3 -1
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/requires.txt +0 -1
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/pyproject.toml +2 -3
- batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/intensity/gamma.py +0 -88
- batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +0 -80
- batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/utils/compose.py +0 -14
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/LICENSE +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2 → batchgeneratorsv2-0.3.2/batchgeneratorsv2/benchmarks}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/benchmarks → batchgeneratorsv2-0.3.2/batchgeneratorsv2/benchmarks/bg_comparison}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/benchmarks/unique_values.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/benchmarks/bg_comparison → batchgeneratorsv2-0.3.2/batchgeneratorsv2/dataloading}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/dataloading → batchgeneratorsv2-0.3.2/batchgeneratorsv2/helpers}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/helpers/scalar_type.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/helpers → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/base}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/base/basic_transform.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/base → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/intensity}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/intensity/inversion.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/intensity → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/local}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/local/brightness_gradient.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/local/local_contrast.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/local/local_gamma.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/local/local_smoothing.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/local/local_transform.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/local → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/nnunet}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/nnunet → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/noise}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/noise/median_filter.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/noise/rician.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/noise → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/spatial}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/mirroring.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/rot90.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/spatial/transpose.py +0 -0
- {batchgeneratorsv2-0.3.0/batchgeneratorsv2/transforms/spatial → batchgeneratorsv2-0.3.2/batchgeneratorsv2/transforms/utils}/__init__.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/cropping.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/batchgeneratorsv2/transforms/utils/random.py +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/readme.md +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/setup.cfg +0 -0
- {batchgeneratorsv2-0.3.0 → batchgeneratorsv2-0.3.2}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: batchgeneratorsv2
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: Batchgenerators but better
|
|
5
5
|
Author: Helmholtz Imaging Applied Computer Vision Lab
|
|
6
6
|
Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
|
|
@@ -220,7 +220,6 @@ Description-Content-Type: text/markdown
|
|
|
220
220
|
License-File: LICENSE
|
|
221
221
|
Requires-Dist: torch>=2.0.0
|
|
222
222
|
Requires-Dist: numpy
|
|
223
|
-
Requires-Dist: fft-conv-pytorch
|
|
224
223
|
Requires-Dist: batchgenerators>=0.25
|
|
225
224
|
Dynamic: license-file
|
|
226
225
|
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from typing import Iterable, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as f
|
|
5
|
+
from torch import Tensor, nn
|
|
6
|
+
from torch.fft import irfftn, rfftn
|
|
7
|
+
from math import ceil, floor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Taken from here: https://github.com/vcasellesb/fft-conv-pytorch/tree/non-tup-slice-fix. THANK YOU!
|
|
11
|
+
# Original codebase: https://github.com/fkodom/fft-conv-pytorch -> unfortunately it has not been updated for PyTorch 2.9.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Multiply two complex-valued tensors over their channel dimensions.

    Performs a matrix product between the channel dimension of ``a`` and the
    first two dimensions of ``b``; all trailing (spatial) dimensions are
    treated as batch dimensions. This is the frequency-domain counterpart of
    the channel mixing done by a convolution. "Grouped" multiplication is
    supported: each group of channels is multiplied independently of the
    others (required for group convolutions).

    Args:
        a: Complex tensor of shape ``(batch, in_channels, *spatial)``.
        b: Complex tensor of shape ``(out_channels, in_channels // groups, *spatial)``.
        groups: Number of channel groups.

    Returns:
        Complex tensor of shape ``(batch, out_channels, *spatial)``. Its dtype
        follows the inputs (complex128 inputs give a complex128 result).
    """
    a = a.view(a.size(0), groups, -1, *a.shape[2:])
    b = b.view(groups, -1, *b.shape[1:])

    # Move the contracted (input-channel) dims to the end so that `@` performs
    # the channel mixing while every spatial position acts as a batch dim.
    a = torch.movedim(a, 2, a.dim() - 1).unsqueeze(-2)
    b = torch.movedim(b, (1, 2), (b.dim() - 1, b.dim() - 2))

    # Complex product expressed with real arithmetic: (ar + i*ai) @ (br + i*bi).
    real = a.real @ b.real - a.imag @ b.imag
    imag = a.imag @ b.real + a.real @ b.imag
    real = torch.movedim(real, real.dim() - 1, 2).squeeze(-1)
    imag = torch.movedim(imag, imag.dim() - 1, 2).squeeze(-1)

    # torch.complex keeps the precision of the inputs instead of forcing
    # complex64; reshape (not view) because the result may be non-contiguous.
    c = torch.complex(real, imag)
    return c.reshape(c.size(0), -1, *c.shape[3:])


def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
    """Cast ``val`` to a tuple of length ``n``.

    Useful for normalizing padding, stride and dilation arguments, where users
    may pass either a single integer or one value per dimension.

    Args:
        val: (Union[int, Iterable[int]]) A scalar (repeated ``n`` times) or an
            iterable that must contain exactly ``n`` elements.
        n: (int) Desired length of the tuple.

    Returns:
        (Tuple[int, ...]) Tuple of length ``n``.

    Raises:
        ValueError: if ``val`` is iterable but does not have length ``n``.
    """
    if not isinstance(val, Iterable):
        # A scalar is simply repeated once per dimension.
        return (val,) * n
    as_tuple = tuple(val)
    if len(as_tuple) != n:
        raise ValueError(f"Cannot cast tuple of length {len(as_tuple)} to length {n}.")
    return as_tuple


def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    padding_mode: str = "constant",
    stride: Union[int, Iterable[int]] = 1,
    dilation: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Perform an N-d convolution using a fast Fourier transform.

    This is very fast for large kernel sizes. Like PyTorch's direct
    convolution, it computes a cross-correlation and optionally adds a bias
    tensor afterwards (in order to mimic the PyTorch direct convolution).

    Args:
        signal: (Tensor) Input of shape ``(batch, in_channels, *spatial)``.
        kernel: (Tensor) Kernel of shape ``(out_channels, in_channels // groups, *kernel_spatial)``.
        bias: (Tensor) Optional bias with ``out_channels`` elements, added to the output.
        padding: (Union[int, Iterable[int], str]) If int/iterable, the number of samples
            padded on both sides of every spatial dimension of the input. If str,
            ``"same"`` pads so the output keeps the input's spatial size (requires
            stride 1 and dilation 1), and ``"valid"`` means no padding.
        padding_mode: (str) Forwarded as ``mode`` to ``torch.nn.functional.pad``
            (e.g. "constant", "reflect", "replicate", "circular"). Which modes are
            available depends on the input dimensionality and the PyTorch version.
        stride: (Union[int, Iterable[int]]) Stride used when computing output values.
        dilation: (Union[int, Iterable[int]]) Dilation rate of the kernel.
        groups: (int) Number of groups for the convolution.

    Returns:
        (Tensor) Convolved tensor. The FFT is computed on ``.float()`` copies of
        the inputs, so the result is float32.

    Raises:
        ValueError: if ``padding`` is an unsupported string, if ``"same"`` is
            combined with a stride or dilation other than 1, or if stride,
            dilation or padding have the wrong number of elements.
    """
    # Cast padding, stride & dilation to tuples.
    n = signal.ndim - 2
    stride_ = to_ntuple(stride, n=n)
    dilation_ = to_ntuple(dilation, n=n)
    if isinstance(padding, str):
        if padding == "same":
            # Check the normalized tuples so that stride=(1, 1) is accepted just like stride=1.
            if any(s != 1 for s in stride_) or any(d != 1 for d in dilation_):
                raise ValueError("stride and dilation must be 1 for padding='same'.")
            # Fractional for even kernels; split into floor/ceil below (extra sample on the right).
            padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
        elif padding == "valid":
            padding_ = (0,) * n
        else:
            raise ValueError(f"Padding mode {padding} not supported.")
    else:
        padding_ = to_ntuple(padding, n=n)

    # Internal dilation offsets: a one-hot block of size `dilation_` per kernel tap.
    offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
    offset[(slice(None), slice(None), *((0,) * n))] = 1.0

    # kron appends (d - 1) zeros after the last tap; cut those trailing zeros off.
    cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)

    # Spread the kernel taps apart according to the dilation parameters.
    kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

    # Pad the input signal (floor/ceil split supports even-sized kernels).
    # f.pad expects the last dimension first, hence the reversal.
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, the final dimension must have
    # *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # padded signal size, before the even-length fix
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    # Zero-pad the kernel on the high side up to the signal's spatial size.
    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT.
    spatial_dims = tuple(range(2, signal.ndim))
    signal_fr = rfftn(signal.float(), dim=spatial_dims)
    kernel_fr = rfftn(padded_kernel.float(), dim=spatial_dims)

    # Conjugating the kernel spectrum yields cross-correlation (PyTorch's convention).
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=spatial_dims)

    # Remove extra padded values and apply the stride by slicing.
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    )
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
|
|
@@ -38,42 +38,58 @@ class BrightnessAdditiveTransform(ImageOnlyTransform):
|
|
|
38
38
|
"""
|
|
39
39
|
Adds random additive brightness noise sampled from a Gaussian distribution (mu, sigma).
|
|
40
40
|
|
|
41
|
-
Supports
|
|
41
|
+
Supports either synchronized brightness shift across all channels or per-channel brightness shift.
|
|
42
42
|
|
|
43
43
|
Args:
|
|
44
44
|
mu (float): Mean of the Gaussian used to sample brightness shifts.
|
|
45
45
|
sigma (float): Standard deviation of the Gaussian.
|
|
46
|
-
|
|
46
|
+
synchronize_channels (bool): If True, brightness shifts are shared across all channels.
|
|
47
47
|
p_per_channel (float): Probability to apply the brightness shift to each channel.
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
50
|
def __init__(self,
|
|
51
51
|
mu: float,
|
|
52
52
|
sigma: float,
|
|
53
|
-
|
|
53
|
+
synchronize_channels: bool = True, # Changed to synchronize_channels
|
|
54
54
|
p_per_channel: float = 1.0):
|
|
55
55
|
super().__init__()
|
|
56
56
|
self.mu = mu
|
|
57
57
|
self.sigma = sigma
|
|
58
|
-
self.
|
|
58
|
+
self.synchronize_channels = synchronize_channels # Now it's being used
|
|
59
59
|
self.p_per_channel = p_per_channel
|
|
60
60
|
|
|
61
|
-
def get_parameters(self,
|
|
62
|
-
|
|
63
|
-
|
|
61
|
+
def get_parameters(self, **data_dict) -> dict:
|
|
62
|
+
img = data_dict["image"]
|
|
63
|
+
c = img.shape[0]
|
|
64
|
+
apply_to_channel = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()
|
|
65
|
+
|
|
66
|
+
if len(apply_to_channel) == 0:
|
|
67
|
+
return {"apply_to_channel": apply_to_channel, "shift": None}
|
|
64
68
|
|
|
65
|
-
|
|
66
|
-
|
|
69
|
+
# Apply either synchronized or per-channel brightness shift
|
|
70
|
+
if self.synchronize_channels:
|
|
71
|
+
shift_value = float(sample_scalar((self.mu, self.sigma), image=img, channel=None))
|
|
72
|
+
shift = torch.full((c,), shift_value, device=img.device)
|
|
67
73
|
else:
|
|
68
|
-
|
|
69
|
-
brightness = [global_brightness if apply else None for apply in apply_channel]
|
|
74
|
+
shift = torch.empty(c, device=img.device).normal_(float(self.mu), float(self.sigma))
|
|
70
75
|
|
|
71
|
-
return {
|
|
76
|
+
return {"apply_to_channel": apply_to_channel, "shift": shift}
|
|
72
77
|
|
|
73
78
|
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
79
|
+
if params["shift"] is None:
|
|
80
|
+
return img
|
|
81
|
+
|
|
82
|
+
apply_idx = params["apply_to_channel"]
|
|
83
|
+
if apply_idx.numel() == 0:
|
|
84
|
+
return img
|
|
85
|
+
|
|
86
|
+
shift = params["shift"]
|
|
87
|
+
# Build full per-channel shift vector; non-selected channels get shift 0
|
|
88
|
+
shift_full = torch.zeros((img.shape[0],), device=img.device, dtype=img.dtype)
|
|
89
|
+
shift_full[apply_idx] = shift[apply_idx]
|
|
90
|
+
|
|
91
|
+
view_shape = (img.shape[0],) + (1,) * (img.ndim - 1)
|
|
92
|
+
img.add_(shift_full.view(view_shape))
|
|
77
93
|
return img
|
|
78
94
|
|
|
79
95
|
|
|
@@ -84,7 +100,8 @@ if __name__ == '__main__':
|
|
|
84
100
|
os.environ['OMP_NUM_THREADS'] = '1'
|
|
85
101
|
torch.set_num_threads(1)
|
|
86
102
|
|
|
87
|
-
mbt =
|
|
103
|
+
# mbt = BrightnessAdditiveTransform(0, 0.5,True, 1)
|
|
104
|
+
mbt = MultiplicativeBrightnessTransform((0.5, 2),False, 1)
|
|
88
105
|
|
|
89
106
|
times_torch = []
|
|
90
107
|
for _ in range(1000):
|
|
@@ -26,44 +26,70 @@ class BGContrast():
|
|
|
26
26
|
return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
|
|
27
27
|
|
|
28
28
|
|
|
29
|
+
import torch
|
|
30
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
31
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
32
|
+
|
|
33
|
+
|
|
29
34
|
class ContrastTransform(ImageOnlyTransform):
|
|
30
|
-
def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
|
|
35
|
+
def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1.0):
|
|
31
36
|
super().__init__()
|
|
32
37
|
self.contrast_range = contrast_range
|
|
33
38
|
self.preserve_range = preserve_range
|
|
34
39
|
self.synchronize_channels = synchronize_channels
|
|
35
|
-
self.p_per_channel = p_per_channel
|
|
40
|
+
self.p_per_channel = float(p_per_channel)
|
|
36
41
|
|
|
37
42
|
def get_parameters(self, **data_dict) -> dict:
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
43
|
+
img = data_dict["image"]
|
|
44
|
+
c = img.shape[0]
|
|
45
|
+
|
|
46
|
+
# sample on correct device
|
|
47
|
+
apply_idx = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()
|
|
48
|
+
n = apply_idx.numel()
|
|
49
|
+
|
|
50
|
+
if n == 0:
|
|
51
|
+
multipliers = None
|
|
52
|
+
elif self.synchronize_channels:
|
|
53
|
+
m = float(sample_scalar(self.contrast_range, image=img, channel=None))
|
|
54
|
+
multipliers = torch.full((n,), m, device=img.device, dtype=img.dtype)
|
|
42
55
|
else:
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
56
|
+
# Still a Python loop because sample_scalar is scalar-by-scalar
|
|
57
|
+
# Use .tolist() to avoid iterating tensor scalars in Python
|
|
58
|
+
ms = [sample_scalar(self.contrast_range, image=img, channel=int(ch)) for ch in apply_idx.tolist()]
|
|
59
|
+
multipliers = torch.as_tensor(ms, device=img.device, dtype=img.dtype)
|
|
60
|
+
|
|
61
|
+
return {"apply_to_channel": apply_idx, "multipliers": multipliers}
|
|
48
62
|
|
|
49
63
|
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
|
|
50
|
-
|
|
64
|
+
idx = params["apply_to_channel"]
|
|
65
|
+
multipliers = params["multipliers"]
|
|
66
|
+
if multipliers is None or idx.numel() == 0:
|
|
51
67
|
return img
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
68
|
+
|
|
69
|
+
if self.preserve_range:
|
|
70
|
+
for i in range(idx.numel()):
|
|
71
|
+
c = int(idx[i])
|
|
72
|
+
m = multipliers[i]
|
|
73
|
+
|
|
74
|
+
x = img[c]
|
|
75
|
+
mean = x.mean()
|
|
76
|
+
minm = x.min()
|
|
77
|
+
maxm = x.max()
|
|
78
|
+
|
|
79
|
+
x.sub_(mean)
|
|
80
|
+
x.mul_(m)
|
|
81
|
+
x.add_(mean)
|
|
82
|
+
x.clamp_(minm, maxm)
|
|
83
|
+
else:
|
|
84
|
+
for i in range(idx.numel()):
|
|
85
|
+
c = int(idx[i])
|
|
86
|
+
m = multipliers[i]
|
|
87
|
+
|
|
88
|
+
x = img[c]
|
|
89
|
+
mean = x.mean()
|
|
90
|
+
x.sub_(mean)
|
|
91
|
+
x.mul_(m)
|
|
92
|
+
x.add_(mean)
|
|
67
93
|
|
|
68
94
|
return img
|
|
69
95
|
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GammaTransform(ImageOnlyTransform):
    """Random per-channel gamma augmentation.

    Each channel is selected with probability ``p_per_channel``. A selected
    channel is min-max normalized, raised to a random power ``gamma``, and
    mapped back to its original intensity range. Two options can be drawn per
    channel:
      * with probability ``p_invert_image``, the channel is negated before and
        after the gamma step;
      * with probability ``p_retain_stats``, the channel's original mean and
        standard deviation are restored afterwards.
    With ``synchronize_channels`` one gamma value is shared by all selected
    channels; otherwise ``gamma`` is sampled per channel.
    """

    def __init__(self,
                 gamma: RandomScalar,
                 p_invert_image: float,
                 synchronize_channels: bool,
                 p_per_channel: float,
                 p_retain_stats: float):
        super().__init__()
        self.gamma = gamma
        # All probabilities are kept as plain Python floats.
        self.p_invert_image, self.synchronize_channels = float(p_invert_image), synchronize_channels
        self.p_per_channel, self.p_retain_stats = float(p_per_channel), float(p_retain_stats)

    def get_parameters(self, **data_dict) -> dict:
        image: torch.Tensor = data_dict["image"]
        dev, dt = image.device, image.dtype

        chosen = torch.rand(image.shape[0], device=dev) < self.p_per_channel
        selected = torch.nonzero(chosen, as_tuple=False).flatten()
        num_selected = selected.numel()
        if not num_selected:
            return {"apply_to_channel": selected,
                    "retain_stats": None,
                    "invert_image": None,
                    "gamma": None}

        # Random draws happen in a fixed order: retain flags, invert flags, then gammas.
        retain_flags = torch.rand(num_selected, device=dev) < self.p_retain_stats
        invert_flags = torch.rand(num_selected, device=dev) < self.p_invert_image

        if not self.synchronize_channels:
            # sample_scalar yields one value per call, so sample channel by channel.
            per_channel_gamma = [float(sample_scalar(self.gamma, image=image, channel=int(ch)))
                                 for ch in selected.tolist()]
            gammas = torch.as_tensor(per_channel_gamma, device=dev, dtype=dt)
        else:
            shared_gamma = float(sample_scalar(self.gamma, image=image, channel=None))
            gammas = torch.full((num_selected,), shared_gamma, device=dev, dtype=dt)

        return {
            "apply_to_channel": selected,
            "retain_stats": retain_flags,
            "invert_image": invert_flags,
            "gamma": gammas,
        }

    @staticmethod
    def _gamma_inplace(x: torch.Tensor, g: torch.Tensor, keep_stats: bool, flip: bool,
                       eps: float = 1e-7) -> None:
        """Apply the (optionally inverted / stats-preserving) gamma step to one channel in place."""
        if flip:
            x.mul_(-1)

        if keep_stats:
            mean_before = x.mean()
            std_before = x.std()

        lo = x.min()
        hi = x.max()
        span = hi - lo
        # x = (((x - lo) / max(span, eps)) ** g) * span + lo
        x.sub_(lo).div_(torch.clamp(span, min=eps)).pow_(g).mul_(span).add_(lo)

        if keep_stats:
            mean_after = x.mean()
            std_after = x.std()
            x.sub_(mean_after).mul_(std_before / torch.clamp(std_after, min=eps)).add_(mean_before)

        if flip:
            x.mul_(-1)

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        channels: torch.Tensor = params["apply_to_channel"]
        if channels.numel() == 0:
            return img

        # Pull the per-channel flags to Python once instead of indexing tensors per channel.
        per_channel = zip(channels.tolist(),
                          params["retain_stats"].tolist(),
                          params["invert_image"].tolist(),
                          params["gamma"])
        for ch, keep_stats, flip, g in per_channel:
            # img[ch] is a view, so the in-place helper writes back into img.
            self._gamma_inplace(img[ch], g, keep_stats, flip)
        return img
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
if __name__ == '__main__':
    # Micro-benchmark: this GammaTransform vs. the original batchgenerators one.
    import os
    from time import time

    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _mean_runtime(transform, make_input, repeats=100):
        """Average wall-clock seconds of ``transform(**make_input())`` over ``repeats`` runs."""
        durations = []
        for _ in range(repeats):
            kwargs = make_input()
            start = time()
            transform(**kwargs)
            durations.append(time() - start)
        return np.mean(durations)

    ours = GammaTransform((0.7, 1.5), 0, False, 1, 1)
    print('torch', _mean_runtime(ours, lambda: {'image': torch.ones((2, 128, 192, 64))}))

    from batchgenerators.transforms.color_transforms import GammaTransform as BGGamma

    theirs = BGGamma((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1)
    print('bg', _mean_runtime(theirs, lambda: {'data': np.ones((1, 2, 128, 192, 64))}))
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GaussianNoiseTransform(ImageOnlyTransform):
    """Add zero-mean Gaussian noise to randomly selected channels.

    Each channel is selected with probability ``p_per_channel``. The value
    drawn from ``noise_variance`` (via ``sample_scalar``) is used as the
    noise *standard deviation*, despite the parameter's name. With
    ``synchronize_channels`` a single value is shared by all selected
    channels; otherwise one value is drawn per selected channel.
    """

    def __init__(self,
                 noise_variance: RandomScalar = (0, 0.1),
                 p_per_channel: float = 1.,
                 synchronize_channels: bool = False):
        super().__init__()
        self.noise_variance = noise_variance
        self.p_per_channel = p_per_channel
        self.synchronize_channels = synchronize_channels

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict["image"]

        # Boolean selection mask on the image's device, plus the selected indices.
        mask = torch.rand(image.shape[0], device=image.device) < self.p_per_channel
        selected = mask.nonzero(as_tuple=False).flatten()
        count = selected.numel()

        if count == 0:
            sigmas = None
        elif not self.synchronize_channels:
            # One independent draw per selected channel.
            sigmas = [sample_scalar(self.noise_variance, image) for _ in range(count)]
        else:
            sigmas = sample_scalar(self.noise_variance, image)

        return {"apply_mask": mask, "apply_idx": selected, "num_apply": count, "sigmas": sigmas}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        count = params["num_apply"]
        if count == 0:
            return img

        selected = params["apply_idx"]
        sigmas = params["sigmas"]
        if sigmas is None:
            return img

        noise_shape = (count, *img.shape[1:])
        if self.synchronize_channels:
            noise = torch.empty(noise_shape, device=img.device, dtype=img.dtype).normal_(
                mean=0.0, std=float(sigmas))
        else:
            # Broadcast one std per selected channel over all spatial positions.
            per_channel_std = torch.as_tensor(sigmas, device=img.device, dtype=img.dtype)
            per_channel_std = per_channel_std.view((count,) + (1,) * (img.ndim - 1))
            noise = torch.empty(noise_shape, device=img.device, dtype=img.dtype).normal_()
            noise.mul_(per_channel_std)

        # img[selected] is a copy under advanced indexing; the augmented
        # assignment writes the sum back via __setitem__.
        img[selected] += noise
        return img
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
if __name__ == "__main__":
    # Micro-benchmark: this GaussianNoiseTransform vs. the original batchgenerators one.
    # `os` was previously used here without being imported (NameError).
    import os
    from time import time
    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    gnt = GaussianNoiseTransform((0, 0.1), 1, False)

    times = []
    for _ in range(1000):
        data_dict = {'image': torch.ones((2, 32, 32, 32))}
        st = time()
        out = gnt(**data_dict)
        times.append(time() - st)
    print('torch', np.mean(times))

    # Alias the import so the module-level GaussianNoiseTransform is not rebound.
    from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform as BGGaussianNoise

    gnt_bg = BGGaussianNoise((0, 0.1), 1, 1, True)

    times = []
    for _ in range(1000):
        data_dict = {'data': np.ones((1, 2, 32, 32, 32))}
        st = time()
        out = gnt_bg(**data_dict)
        times.append(time() - st)

    print('bg', np.mean(times))
    # torch is 2.5x faster
|
|
@@ -67,12 +67,11 @@ class CutOffOutliersTransform(ImageOnlyTransform):
|
|
|
67
67
|
if retain_std[c]:
|
|
68
68
|
orig_std = img_c.std()
|
|
69
69
|
|
|
70
|
-
#
|
|
71
|
-
|
|
72
|
-
lower_val =
|
|
73
|
-
upper_val = np.percentile(img_c_np, perc[1])
|
|
70
|
+
# Percentiles in torch to avoid numpy roundtrip
|
|
71
|
+
q = torch.tensor([perc[0] / 100.0, perc[1] / 100.0], device=img_c.device, dtype=torch.float32)
|
|
72
|
+
lower_val, upper_val = torch.quantile(img_c.float(), q)
|
|
74
73
|
|
|
75
|
-
img_c_clipped = img_c.clamp(min=
|
|
74
|
+
img_c_clipped = img_c.clamp(min=lower_val.item(), max=upper_val.item())
|
|
76
75
|
|
|
77
76
|
if retain_std[c]:
|
|
78
77
|
clipped_std = img_c_clipped.std()
|
|
@@ -99,4 +98,4 @@ if __name__ == '__main__':
|
|
|
99
98
|
params = transform.get_parameters(image=image)
|
|
100
99
|
image_clipped = transform._apply_to_image(image.clone(), **params)
|
|
101
100
|
|
|
102
|
-
view_batch(image, image_clipped, image_clipped-image)
|
|
101
|
+
view_batch(image, image_clipped, image_clipped-image)
|
|
@@ -3,14 +3,11 @@ from typing import Union, List, Tuple, Callable
|
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import torch
|
|
6
|
-
from
|
|
6
|
+
from skimage.morphology import ball, disk
|
|
7
7
|
|
|
8
|
+
from batchgeneratorsv2.helpers.fft_conv import fft_conv
|
|
8
9
|
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
9
10
|
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
10
|
-
from skimage.morphology import ball, disk
|
|
11
|
-
from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
|
|
12
|
-
|
|
13
|
-
import torch.nn.functional as F
|
|
14
11
|
|
|
15
12
|
|
|
16
13
|
def binary_dilation_torch(input_tensor, structure_element):
|
|
@@ -71,7 +71,7 @@ class BlankRectangleTransform(ImageOnlyTransform):
|
|
|
71
71
|
}
|
|
72
72
|
|
|
73
73
|
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
|
|
74
|
-
out = img
|
|
74
|
+
out = img
|
|
75
75
|
for c, (apply, rects) in enumerate(zip(params['apply_channel'], params['rectangles'])):
|
|
76
76
|
if not apply:
|
|
77
77
|
continue
|
|
@@ -147,4 +147,4 @@ if __name__ == '__main__':
|
|
|
147
147
|
image_aug = transform._apply_to_image(image, **params)
|
|
148
148
|
|
|
149
149
|
from batchviewer import view_batch
|
|
150
|
-
view_batch(image, image_aug)
|
|
150
|
+
view_batch(image, image_aug)
|
|
@@ -6,9 +6,9 @@ import torch
|
|
|
6
6
|
from skimage.data import camera
|
|
7
7
|
from torch.nn.functional import pad, conv3d, conv1d, conv2d
|
|
8
8
|
|
|
9
|
+
from batchgeneratorsv2.helpers.fft_conv import fft_conv
|
|
9
10
|
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
10
11
|
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
11
|
-
from fft_conv_pytorch import fft_conv
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def blur_dimension(img: torch.Tensor, sigma: float, dim_to_blur: int, force_use_fft: bool = None, truncate: float = 6):
|
|
@@ -219,7 +219,7 @@ if __name__ == "__main__":
|
|
|
219
219
|
shape = (128, 164, 64)
|
|
220
220
|
num_warmup_for_benchmark = 1
|
|
221
221
|
num_repeats = 10
|
|
222
|
-
for sigma_range in (0.1, 1, 10):
|
|
222
|
+
for sigma_range in (0.1, 1, 10, 20):
|
|
223
223
|
print(shape, sigma_range)
|
|
224
224
|
gnt2 = GaussianBlurTransform(sigma_range, False, False, 1, benchmark=False)
|
|
225
225
|
times = []
|
|
@@ -257,4 +257,4 @@ if __name__ == "__main__":
|
|
|
257
257
|
print('batchgenerator', np.median(times))
|
|
258
258
|
print()
|
|
259
259
|
#
|
|
260
|
-
# print(gnt.benchmark_use_fft)
|
|
260
|
+
# print(gnt.benchmark_use_fft)
|