batchgeneratorsv2 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2/helpers/fft_conv.py +149 -0
- batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
- batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
- batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
- batchgeneratorsv2/transforms/local/__init__.py +0 -0
- batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
- batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
- batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
- batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
- batchgeneratorsv2/transforms/local/local_transform.py +86 -0
- batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
- batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
- batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
- batchgeneratorsv2/transforms/noise/rician.py +61 -0
- batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
- batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
- batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
- batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
- batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
- batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
- batchgeneratorsv2/transforms/utils/random.py +46 -0
- batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
- batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
- batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
- batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
- batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
- batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
|
|
6
|
+
ContrastAugmentationTransform, GammaTransform
|
|
7
|
+
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
|
|
8
|
+
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
|
|
9
|
+
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
|
|
10
|
+
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, NumpyToTensor
|
|
11
|
+
from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
|
|
12
|
+
DownsampleSegForDSTransform2
|
|
13
|
+
from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
|
|
14
|
+
from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
|
|
15
|
+
ConvertSegmentationToRegionsTransform
|
|
16
|
+
from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
|
|
17
|
+
Convert2DTo3DTransform
|
|
18
|
+
|
|
19
|
+
if __name__ == '__main__':
    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2 * np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    # Legacy batchgenerators (v1) version of the training augmentation pipeline, timed per transform.
    tr_transforms = []
    if do_dummy_2d_data_aug:
        # Pseudo-2D path (disabled in this benchmark).
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        ignore_axes = None
        patch_size_spatial = patch_size

    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
        do_rotation=True, angle_x=rotation_for_DA, angle_y=rotation_for_DA, angle_z=rotation_for_DA,
        p_rot_per_axis=1,  # todo experiment with this
        do_scale=True, scale=(0.7, 1.4),
        border_mode_data="constant", border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=-1, order_seg=1,
        random_crop=False,  # random cropping is part of our dataloaders
        p_el_per_sample=0, p_scale_per_sample=1, p_rot_per_sample=1,
        independent_scale_for_each_axis=False  # todo experiment with this
    ))
    if do_dummy_2d_data_aug:
        tr_transforms.append(Convert2DTo3DTransform())

    # Intensity augmentations followed by the label post-processing steps.
    tr_transforms.extend([
        GaussianNoiseTransform(p_per_sample=1, p_per_channel=1),
        GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=1, p_per_channel=1),
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=1),
        ContrastAugmentationTransform(p_per_sample=1, p_per_channel=1),
        SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, p_per_channel=1,
                                       order_downsample=0, order_upsample=3, p_per_sample=1,
                                       ignore_axes=ignore_axes),
        GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=1),
        GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1),
        MirrorTransform((0, 1, 2)),
        MaskTransform([0, 1, 2, 3], mask_idx_in_seg=0, set_outside_to=0),
        RemoveLabelTransform(-1, 0),
        ConvertSegmentationToRegionsTransform(regions, 'seg', 'seg'),
        DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='seg', output_key='seg'),
        NumpyToTensor(['data', 'seg'], 'float'),
    ])

    # One list of per-call durations for each transform, in pipeline order.
    compute_times = [[] for _ in tr_transforms]

    torch.set_num_threads(1)
    for iteration in range(50):
        print(iteration)
        data = np.random.uniform(size=(1, 4, 128, 128, 128))
        seg = np.round(4.5 * np.random.uniform(size=(1, 1, 128, 128, 128)) - 1, decimals=0).astype(np.int8)
        data_dict = {'data': data, 'seg': seg}
        for transform, durations in zip(tr_transforms, compute_times):
            start = time()
            data_dict = transform(**data_dict)
            durations.append(time() - start)

    for transform, durations in zip(tr_transforms, compute_times):
        print(type(transform).__name__, np.mean(durations))
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
|
|
7
|
+
from batchgeneratorsv2.transforms.intensity.contrast import BGContrast, ContrastTransform
|
|
8
|
+
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
|
|
9
|
+
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
|
|
10
|
+
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
|
|
11
|
+
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
|
|
12
|
+
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
|
|
13
|
+
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
|
|
14
|
+
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
|
|
15
|
+
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
|
|
16
|
+
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
|
|
17
|
+
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert2DTo3DTransform, Convert3DTo2DTransform
|
|
18
|
+
from batchgeneratorsv2.transforms.utils.random import RandomTransform
|
|
19
|
+
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
|
|
20
|
+
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
|
|
21
|
+
|
|
22
|
+
if __name__ == '__main__':
    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2 * np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    # batchgeneratorsv2 version of the training augmentation pipeline, timed per transform.
    transforms = []
    if do_dummy_2d_data_aug:
        # Pseudo-2D path (disabled in this benchmark).
        ignore_axes = (0,)
        transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        ignore_axes = None
        patch_size_spatial = patch_size

    transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
        p_rotation=1, rotation=rotation_for_DA,
        p_scaling=1, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1
    ))
    if do_dummy_2d_data_aug:
        transforms.append(Convert2DTo3DTransform())

    # Intensity augmentations followed by the label post-processing steps.
    transforms.extend([
        GaussianNoiseTransform(noise_variance=(0, 0.1), p_per_channel=1, synchronize_channels=True),
        GaussianBlurTransform(blur_sigma=(0.5, 1.), synchronize_channels=False, synchronize_axes=False,
                              p_per_channel=1, benchmark=True),
        MultiplicativeBrightnessTransform(multiplier_range=BGContrast((0.75, 1.25)),
                                          synchronize_channels=False, p_per_channel=1),
        ContrastTransform(contrast_range=BGContrast((0.75, 1.25)), preserve_range=True,
                          synchronize_channels=False, p_per_channel=1),
        SimulateLowResolutionTransform(scale=(0.5, 1), synchronize_channels=False, synchronize_axes=True,
                                       ignore_axes=ignore_axes, allowed_channels=None, p_per_channel=1),
        GammaTransform(gamma=BGContrast((0.7, 1.5)), p_invert_image=1, synchronize_channels=False,
                       p_per_channel=1, p_retain_stats=1),
        GammaTransform(gamma=BGContrast((0.7, 1.5)), p_invert_image=0, synchronize_channels=False,
                       p_per_channel=1, p_retain_stats=1),
        MirrorTransform(allowed_axes=(0, 1, 2)),
        MaskImageTransform(apply_to_channels=[0, 1, 2, 3], channel_idx_in_seg=0, set_outside_to=0),
        RemoveLabelTansform(-1, 0),
        ConvertSegmentationToRegionsTransform(regions=regions, channel_in_seg=0),
        DownsampleSegForDSTransform(ds_scales=deep_supervision_scales),
    ])

    # One list of per-call durations for each transform, in pipeline order.
    compute_times = [[] for _ in transforms]

    with torch.no_grad():
        torch.set_num_threads(1)
        for iteration in range(50):
            print(iteration)
            image = torch.rand((4, 128, 128, 128))
            segmentation = torch.round(4.5 * torch.rand((1, 128, 128, 128)) - 1, decimals=0).to(torch.int8)
            data_dict = {'image': image, 'segmentation': segmentation}
            for transform, durations in zip(transforms, compute_times):
                start = time()
                data_dict = transform(**data_dict)
                durations.append(time() - start)

    for transform, durations in zip(transforms, compute_times):
        # Report RandomTransform wrappers under the name of the transform they wrap.
        reported = transform.transform if isinstance(transform, RandomTransform) else transform
        print(type(reported).__name__, np.mean(durations))
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from time import time
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
def unique_torch(tensor):
    """Return the sorted unique values of ``tensor`` using ``Tensor.unique``."""
    return tensor.unique()
|
|
8
|
+
|
|
9
|
+
def unique_npy(tensor):
    """Return the sorted unique values of a CPU tensor, computed with NumPy."""
    as_array = tensor.numpy()
    return np.unique(as_array)
|
|
11
|
+
|
|
12
|
+
def unique_pandas(tensor):
    """Return the sorted unique values of a CPU tensor, computed with ``pandas.unique``.

    ``pd.unique`` returns values in order of appearance, so the result is sorted
    to match ``np.unique`` / ``torch.unique``.
    """
    # Previously the sorted array was computed but never returned (the function returned None).
    return np.sort(pd.unique(tensor.numpy().ravel()))
|
|
14
|
+
|
|
15
|
+
def unique_bincount(tensor):
    """Return the sorted unique values of a non-negative integer tensor via a histogram.

    Every bin of ``torch.bincount`` with a non-zero count corresponds to a value that occurs.
    """
    counts = torch.bincount(tensor.flatten())
    present = counts > 0
    return torch.nonzero(present, as_tuple=True)[0]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
if __name__ == '__main__':
    torch.set_num_threads(1)
    shape = (64, 64, 64)
    labels = 200  # not used below: the random segmentations only contain the values 0..20

    # (printed name, implementation) pairs, benchmarked in this order.
    candidates = (
        ('unique_torch', unique_torch),
        ('unique_npy', unique_npy),
        ('unique_pandas', unique_pandas),
        ('unique_bincount', unique_bincount),
    )
    for bench_name, unique_fn in candidates:
        durations = []
        for _ in range(10):
            seg = torch.round(torch.rand(shape) * 20, decimals=0).to(torch.uint8)
            start = time()
            unique = unique_fn(seg)
            durations.append(time() - start)
        print(bench_name, np.median(durations))
|
|
55
|
+
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from typing import Iterable, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as f
|
|
5
|
+
from torch import Tensor, nn
|
|
6
|
+
from torch.fft import irfftn, rfftn
|
|
7
|
+
from math import ceil, floor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Taken from here: https://github.com/vcasellesb/fft-conv-pytorch/tree/non-tup-slice-fix. THANK YOU!
|
|
11
|
+
# Original codebase is here: https://github.com/fkodom/fft-conv-pytorch -> unfortunately was not updated for pytorch 2.9.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Multiplies two complex-valued tensors channel-wise, with optional channel groups.

    Args:
        a: complex tensor of shape (batch, in_channels, *spatial).
        b: complex tensor of shape (out_channels, in_channels // groups, *spatial).
        groups: number of independent channel groups (as in grouped convolutions);
            in_channels and out_channels must be divisible by it.

    Returns:
        Complex tensor of shape (batch, out_channels, *spatial) where, with
        g = o // (out_channels // groups),
        out[n, o, ...] = sum_i a[n, g * (in_channels // groups) + i, ...] * b[o, i, ...].
        The complex dtype of the inputs is preserved (complex64 stays complex64,
        complex128 stays complex128).
    """
    # Split the channel dims into (groups, channels_per_group).
    a = a.view(a.size(0), groups, -1, *a.shape[2:])
    b = b.view(groups, -1, *b.shape[1:])

    # Move the contracted (input-channel) dim last so that a batched `@` over the
    # spatial positions does the channel reduction.
    a = torch.movedim(a, 2, a.dim() - 1).unsqueeze(-2)
    b = torch.movedim(b, (1, 2), (b.dim() - 1, b.dim() - 2))

    # Complex matrix multiplication written out via real/imaginary parts.
    real = a.real @ b.real - a.imag @ b.imag
    imag = a.imag @ b.real + a.real @ b.imag
    real = torch.movedim(real, real.dim() - 1, 2).squeeze(-1)
    imag = torch.movedim(imag, imag.dim() - 1, 2).squeeze(-1)

    # torch.complex keeps the precision of real/imag; the former hard-coded complex64
    # buffer silently downcast complex128 inputs.
    c = torch.complex(real, imag)

    # reshape (not view): the parts above are non-contiguous permutations.
    return c.reshape(c.size(0), -1, *c.shape[3:])
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
    """Normalise ``val`` into a tuple of exactly ``n`` entries.

    Handy for convolution arguments (padding, stride, dilation) that may be given
    either as a single number or as one value per spatial dimension.

    Args:
        val: (Union[int, Iterable[int]]) A scalar to repeat ``n`` times, or an iterable
            that must already contain ``n`` items.
        n: (int) Required tuple length.

    Returns:
        (Tuple[int, ...]) The normalised tuple.

    Raises:
        ValueError: If ``val`` is iterable but does not have ``n`` items.
    """
    if not isinstance(val, Iterable):
        return (val,) * n
    as_tuple = tuple(val)
    if len(as_tuple) != n:
        raise ValueError(f"Cannot cast tuple of length {len(as_tuple)} to length {n}.")
    return as_tuple
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
    padding: Union[int, Iterable[int], str] = 0,
    padding_mode: str = "constant",
    stride: Union[int, Iterable[int]] = 1,
    dilation: Union[int, Iterable[int]] = 1,
    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast Fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
    the convolution (in order to mimic the PyTorch direct convolution).

    Args:
        signal: (Tensor) Input of shape (batch, in_channels, *spatial) to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel of shape (out_channels, in_channels // groups, *kernel_spatial).
        bias: (Tensor) Optional per-output-channel bias added to the output.
        padding: (Union[int, Iterable[int], str]) If int/iterable, number of samples to pad the
            input with on both sides of each spatial dimension. If str, only "same" is supported,
            which pads the input so that the spatial size is preserved.
        padding_mode: (str) Mode passed to ``torch.nn.functional.pad`` for the input padding
            (e.g. "constant", "reflect", "replicate").
        stride: (Union[int, Iterable[int]]) Stride size for computing output values.
        dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
        groups: (int) Number of groups for the convolution.

    Returns:
        (Tensor) Convolved tensor (float32, since the FFT is computed on ``.float()`` copies).

    Raises:
        ValueError: If ``padding`` is a string other than "same", if "same" padding is combined
            with a stride or dilation other than 1, or if padding/stride/dilation have a length
            that does not match the number of spatial dimensions.
    """
    # Cast padding, stride & dilation to tuples.
    n = signal.ndim - 2
    stride_ = to_ntuple(stride, n=n)
    dilation_ = to_ntuple(dilation, n=n)
    if isinstance(padding, str):
        if padding != "same":
            raise ValueError(f"Padding mode {padding} not supported.")
        # Check the normalised tuples so that e.g. stride=(1, 1, 1) is accepted like stride=1.
        if any(s != 1 for s in stride_) or any(d != 1 for d in dilation_):
            raise ValueError("stride and dilation must be 1 for padding='same'.")
        # (k - 1) / 2 per side; non-integer values are split floor/ceil when padding below.
        padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
    else:
        padding_ = to_ntuple(padding, n=n)

    # Internal dilation offsets: a block of size `dilation_` with a single 1 in its first corner.
    offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
    offset[(slice(None), slice(None), *((0,) * n))] = 1.0

    # Correct the kernel by cutting off unwanted dilation trailing zeros.
    cutoff = tuple(slice(None, -d + 1 if d != 1 else None) for d in dilation_)

    # Pad the kernel internally according to the dilation parameters.
    kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

    # Pad the input signal (floor/ceil rounding supports even-sized kernels).
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    signal = f.pad(signal, signal_padding, mode=padding_mode)

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
    signal_size = signal.size()  # signal size before the extra padding to even length
    if signal.size(-1) % 2 != 0:
        signal = f.pad(signal, [0, 1])

    # Zero-pad the kernel at the end of every spatial dim up to the signal's size.
    kernel_padding = [
        pad
        for i in reversed(range(2, signal.ndim))
        for pad in [0, signal.size(i) - kernel.size(i)]
    ]
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform Fourier convolution -- FFT, matrix multiply, then IFFT.
    fft_dims = tuple(range(2, signal.ndim))
    signal_fr = rfftn(signal.float(), dim=fft_dims)
    kernel_fr = rfftn(padded_kernel.float(), dim=fft_dims)

    # Conjugate the kernel spectrum so the product corresponds to the cross-correlation
    # that PyTorch's direct convolution computes.
    kernel_fr.imag *= -1
    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=fft_dims)

    # Remove extra padded values and apply the stride.
    crop_slices = (slice(None), slice(None)) + tuple(
        slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    )
    output = output[crop_slices].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        bias_shape = tuple([1, -1] + (signal.ndim - 2) * [1])
        output += bias.view(bias_shape)

    return output
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Union, Tuple, Callable
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
RandomScalar = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]]


def sample_scalar(scalar_type: RandomScalar, *args, **kwargs):
    """Draw a single scalar according to ``scalar_type``.

    Args:
        scalar_type: How to obtain the value:
            - int / float: returned unchanged.
            - list / tuple of length 2 ``(low, high)`` with ``low <= high``: a sample from
              ``np.random.uniform(low, high)``, or ``low`` itself when ``low == high``.
            - callable: called as ``scalar_type(*args, **kwargs)`` and its result returned.
        *args, **kwargs: Forwarded to ``scalar_type`` when it is callable; ignored otherwise.

    Returns:
        The sampled (or fixed) scalar.

    Raises:
        AssertionError: If a list/tuple does not have length 2 or its first entry exceeds the second.
        RuntimeError: If ``scalar_type`` is of an unsupported type.
    """
    if isinstance(scalar_type, (int, float)):
        return scalar_type
    elif isinstance(scalar_type, (list, tuple)):
        assert len(scalar_type) == 2, 'if list is provided, its length must be 2'
        assert scalar_type[0] <= scalar_type[1], 'if list is provided, first entry must be smaller or equal than second entry, ' \
                                                 'otherwise we cannot sample using np.random.uniform'
        if scalar_type[0] == scalar_type[1]:
            return scalar_type[0]
        return np.random.uniform(*scalar_type)
    elif callable(scalar_type):
        return scalar_type(*args, **kwargs)
    else:
        # Format the message explicitly: exceptions do not apply printf-style formatting to
        # extra args, so the previous call produced a raw ('Unknown type: %s...', <type>) tuple.
        raise RuntimeError('Unknown type: %s. Expected: int, float, list, tuple, callable' % type(scalar_type))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
if __name__ == '__main__':
    # Smoke test: exercise each supported kind of RandomScalar once
    # (constant, (low, high) range, zero-argument callable, callable with forwarded args).
    examples = (
        (0.5, ()),
        ((0, 1), ()),
        (lambda: np.random.uniform(-1, 2), ()),
        (lambda x, y: np.random.uniform(x, y), (0.5, 2)),
    )
    for spec, extra_args in examples:
        sample_scalar(spec, *extra_args)
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BasicTransform(abc.ABC):
    """Base class for per-sample data augmentation transforms.

    Transforms work on a single sample; collating samples into batches is the dataloader's job
    (or that of a dedicated collate transform).

    ``image`` and ``segmentation`` are expected as (C, X, Y) or (C, X, Y, Z) tensors; the
    segmentation may have several channels as well. The formats of ``keypoints`` and ``bbox``
    have not been defined yet.

    Calling a transform computes its parameters once with :meth:`get_parameters` and passes
    them to every ``_apply_to_*`` hook via :meth:`apply`. The default hooks return ``None``,
    so a subclass must override the hook of every entry it expects to receive; otherwise
    :meth:`apply` replaces that (non-None) entry with ``None``.
    """

    def __init__(self):
        pass

    def __call__(self, **data_dict) -> dict:
        return self.apply(data_dict, **self.get_parameters(**data_dict))

    def apply(self, data_dict, **params):
        """Run each ``_apply_to_*`` hook on its entry if that entry is present and not None.

        Entries are handled in the order image, regression_target, segmentation, keypoints,
        bbox; the dict is updated in place and returned.
        """
        dispatch = (
            ('image', self._apply_to_image),
            ('regression_target', self._apply_to_regr_target),
            ('segmentation', self._apply_to_segmentation),
            ('keypoints', self._apply_to_keypoints),
            ('bbox', self._apply_to_bbox),
        )
        for key, hook in dispatch:
            value = data_dict.get(key)
            if value is not None:
                data_dict[key] = hook(value, **params)
        return data_dict

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Hook for the image; the default implementation returns None."""
        return None

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        """Hook for the regression target; the default implementation returns None."""
        return None

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        """Hook for the segmentation; the default implementation returns None."""
        return None

    def _apply_to_keypoints(self, keypoints, **params):
        """Hook for keypoints; the default implementation returns None."""
        return None

    def _apply_to_bbox(self, bbox, **params):
        """Hook for bounding boxes; the default implementation returns None."""
        return None

    def get_parameters(self, **data_dict) -> dict:
        """Return the keyword arguments forwarded to the hooks; none by default."""
        return {}

    def __repr__(self):
        fields = ", ".join(f"{name} = {value!r}" for name, value in self.__dict__.items())
        return f"{type(self).__name__}( {fields} )"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ImageOnlyTransform(BasicTransform):
    """Transform that only touches the ``image`` entry of the data dict.

    All other entries (segmentation, regression target, keypoints, bbox) are left as they are.
    """

    def apply(self, data_dict: dict, **params) -> dict:
        image = data_dict.get('image')
        if image is not None:
            data_dict['image'] = self._apply_to_image(image, **params)
        return data_dict
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class SegOnlyTransform(BasicTransform):
    """Transform that only touches the ``segmentation`` entry of the data dict.

    All other entries (image, regression target, keypoints, bbox) are left as they are.
    """

    def apply(self, data_dict: dict, **params) -> dict:
        segmentation = data_dict.get('segmentation')
        if segmentation is not None:
            data_dict['segmentation'] = self._apply_to_segmentation(segmentation, **params)
        return data_dict
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
if __name__ == '__main__':
    # Nothing to run: this module only defines the transform base classes.
    pass
|
|
File without changes
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MultiplicativeBrightnessTransform(ImageOnlyTransform):
    """Multiply randomly selected image channels by sampled brightness factors.

    Each channel is selected independently with probability ``p_per_channel``. Factors are drawn
    with ``sample_scalar(multiplier_range, image=..., channel=...)``. If ``synchronize_channels``
    is True, one factor (drawn with ``channel=None``) is shared by all selected channels.
    Otherwise each selected channel gets its own factor. The image is modified in place.
    """

    def __init__(self, multiplier_range: RandomScalar, synchronize_channels: bool, p_per_channel: float = 1):
        super().__init__()
        self.multiplier_range = multiplier_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict['image']
        n_channels = image.shape[0]
        selected = torch.where(torch.rand(n_channels) < self.p_per_channel)[0]
        if self.synchronize_channels:
            shared_factor = sample_scalar(self.multiplier_range, image=image, channel=None)
            factors = [shared_factor] * len(selected)
        else:
            factors = [sample_scalar(self.multiplier_range, image=image, channel=c) for c in selected]
        return {
            'apply_to_channel': selected,
            'multipliers': torch.Tensor(factors),
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        selected = params['apply_to_channel']
        if len(selected) == 0:
            return img
        # Deliberately a per-channel loop: the vectorised form
        # img[selected] *= multipliers.view(-1, *[1] * (img.ndim - 1)) was measured to be much slower.
        for channel, factor in zip(selected, params['multipliers']):
            img[channel] *= factor
        return img
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BrightnessAdditiveTransform(ImageOnlyTransform):
    """
    Adds random additive brightness shifts sampled from a Gaussian distribution N(mu, sigma).

    Supports either a synchronized brightness shift across all channels or a per-channel brightness shift.
    Channels are selected independently with probability ``p_per_channel``; unselected channels are not
    shifted. The image is modified in place.

    Args:
        mu (float): Mean of the Gaussian used to sample brightness shifts.
        sigma (float): Standard deviation of the Gaussian (must be >= 0).
        synchronize_channels (bool): If True, one brightness shift is shared across all channels.
        p_per_channel (float): Probability to apply the brightness shift to each channel.
    """

    def __init__(self,
                 mu: float,
                 sigma: float,
                 synchronize_channels: bool = True,
                 p_per_channel: float = 1.0):
        super().__init__()
        self.mu = mu
        self.sigma = sigma
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        img = data_dict["image"]
        c = img.shape[0]
        apply_to_channel = (torch.rand(c, device=img.device) < self.p_per_channel).nonzero(as_tuple=False).flatten()

        if len(apply_to_channel) == 0:
            return {"apply_to_channel": apply_to_channel, "shift": None}

        if self.synchronize_channels:
            # A single draw from N(mu, sigma), shared by all channels. (sample_scalar((mu, sigma)) would
            # sample Uniform(mu, sigma) instead and assert that mu <= sigma.)
            shift_value = torch.empty(1, device=img.device).normal_(float(self.mu), float(self.sigma)).item()
            shift = torch.full((c,), shift_value, device=img.device)
        else:
            # Independent draws from N(mu, sigma), one per channel.
            shift = torch.empty(c, device=img.device).normal_(float(self.mu), float(self.sigma))

        return {"apply_to_channel": apply_to_channel, "shift": shift}

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        if params["shift"] is None:
            return img

        apply_idx = params["apply_to_channel"]
        if apply_idx.numel() == 0:
            return img

        shift = params["shift"]
        # Build full per-channel shift vector; non-selected channels get shift 0
        shift_full = torch.zeros((img.shape[0],), device=img.device, dtype=img.dtype)
        shift_full[apply_idx] = shift[apply_idx]

        # Broadcast the per-channel shift over all remaining (spatial) dimensions.
        view_shape = (img.shape[0],) + (1,) * (img.ndim - 1)
        img.add_(shift_full.view(view_shape))
        return img
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
if __name__ == '__main__':
    import os
    from time import time

    # Single-threaded so the two implementations are timed under the same conditions.
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _mean_runtime(transform, make_data_dict, n_runs=1000):
        """Mean wall-clock time of ``transform`` over ``n_runs`` freshly created inputs."""
        durations = []
        for _ in range(n_runs):
            data_dict = make_data_dict()
            start = time()
            out = transform(**data_dict)
            durations.append(time() - start)
        return np.mean(durations)

    # Alternative to benchmark: BrightnessAdditiveTransform(0, 0.5, True, 1)
    mbt = MultiplicativeBrightnessTransform((0.5, 2), False, 1)
    print('torch', _mean_runtime(mbt, lambda: {'image': torch.ones((2, 128, 192, 64))}))

    # Reference: the batchgenerators (v1) implementation of the same augmentation.
    from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform

    gnt_bg = BrightnessMultiplicativeTransform((0.5, 2), True, p_per_sample=1)
    print('bg', _mean_runtime(gnt_bg, lambda: {'data': np.ones((1, 2, 128, 192, 64))}))
|