batchgeneratorsv2 0.1__tar.gz → 0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/PKG-INFO +1 -1
- batchgeneratorsv2-0.2/batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/intensity/brightness.py +63 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/intensity/contrast.py +96 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/intensity/gamma.py +88 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +80 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +193 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/spatial/low_resolution.py +88 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/spatial/spatial.py +509 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/compose.py +14 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/nnunet_masking.py +22 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/pseudo2d.py +81 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/random.py +23 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/remove_label.py +24 -0
- batchgeneratorsv2-0.2/batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/PKG-INFO +1 -1
- batchgeneratorsv2-0.2/batchgeneratorsv2.egg-info/SOURCES.txt +45 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/pyproject.toml +4 -3
- batchgeneratorsv2-0.1/batchgeneratorsv2.egg-info/SOURCES.txt +0 -10
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/LICENSE +0 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2/__init__.py +0 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/requires.txt +0 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/readme.md +0 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/setup.cfg +0 -0
- {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.2}/setup.py +0 -0
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
|
|
6
|
+
ContrastAugmentationTransform, GammaTransform
|
|
7
|
+
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
|
|
8
|
+
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
|
|
9
|
+
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
|
|
10
|
+
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, NumpyToTensor
|
|
11
|
+
from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
|
|
12
|
+
DownsampleSegForDSTransform2
|
|
13
|
+
from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
|
|
14
|
+
from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
|
|
15
|
+
ConvertSegmentationToRegionsTransform
|
|
16
|
+
from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
|
|
17
|
+
Convert2DTo3DTransform
|
|
18
|
+
|
|
19
|
+
if __name__ == '__main__':
    # Benchmark the original (numpy based) batchgenerators version of the nnU-Net training augmentation pipeline.
    # Every transform is timed individually over 50 random samples and its mean runtime is printed at the end.
    # perf_counter is a monotonic, high-resolution clock and therefore better suited for timing than time().
    from time import perf_counter

    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2*np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    tr_transforms = []
    if do_dummy_2d_data_aug:
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
        do_rotation=True, angle_x=rotation_for_DA, angle_y=rotation_for_DA, angle_z=rotation_for_DA,
        p_rot_per_axis=1,  # todo experiment with this
        do_scale=True, scale=(0.7, 1.4),
        border_mode_data="constant", border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=-1, order_seg=1,
        random_crop=False,  # random cropping is part of our dataloaders
        p_el_per_sample=0, p_scale_per_sample=1, p_rot_per_sample=1,
        independent_scale_for_each_axis=False  # todo experiment with this
    ))

    if do_dummy_2d_data_aug:
        tr_transforms.append(Convert2DTo3DTransform())

    tr_transforms.append(GaussianNoiseTransform(p_per_sample=1, p_per_channel=1))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=1,
                                               p_per_channel=1))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=1))
    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=1, p_per_channel=1))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=1,
                                                        order_downsample=0, order_upsample=3, p_per_sample=1,
                                                        ignore_axes=ignore_axes))
    tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=1))
    tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1))

    tr_transforms.append(MirrorTransform((0, 1, 2)))

    tr_transforms.append(MaskTransform([0, 1, 2, 3],
                                       mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'seg', 'seg'))

    tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='seg',
                                                      output_key='seg'))

    tr_transforms.append(NumpyToTensor(['data', 'seg'], 'float'))

    # one list of per-iteration runtimes for every transform
    compute_times = [[] for _ in tr_transforms]

    torch.set_num_threads(1)
    for iteration in range(50):
        print(iteration)
        # batchgenerators expects a leading batch axis: (b, c, x, y, z)
        data_dict = {'data': np.random.uniform(size=(1, 4, 128, 128, 128)),
                     'seg': np.round(4.5 * np.random.uniform(size=(1, 1, 128, 128, 128)) - 1,
                                     decimals=0).astype(np.int8)}
        for i, t in enumerate(tr_transforms):
            st = perf_counter()
            data_dict = t(**data_dict)
            compute_times[i].append(perf_counter() - st)

    for t, ct in zip(tr_transforms, compute_times):
        print(t.__class__.__name__, np.mean(ct))
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from time import time
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
|
|
7
|
+
from batchgeneratorsv2.transforms.intensity.contrast import BGContrast, ContrastTransform
|
|
8
|
+
from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
|
|
9
|
+
from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
|
|
10
|
+
from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
|
|
11
|
+
from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
|
|
12
|
+
from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
|
|
13
|
+
from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
|
|
14
|
+
from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
|
|
15
|
+
from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
|
|
16
|
+
from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
|
|
17
|
+
from batchgeneratorsv2.transforms.utils.pseudo2d import Convert2DTo3DTransform, Convert3DTo2DTransform
|
|
18
|
+
from batchgeneratorsv2.transforms.utils.random import RandomTransform
|
|
19
|
+
from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
|
|
20
|
+
from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
|
|
21
|
+
|
|
22
|
+
if __name__ == '__main__':
    # Benchmark the torch based batchgeneratorsv2 re-implementation of the nnU-Net training augmentation pipeline.
    # The configuration mirrors nnUNet_pipeline_bg.py. Every transform is timed individually over 50 random samples
    # and its mean runtime is printed at the end. perf_counter is a monotonic, high-resolution clock and therefore
    # better suited for timing than time().
    from time import perf_counter

    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2*np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    transforms = []
    if do_dummy_2d_data_aug:
        ignore_axes = (0,)
        transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None
    transforms.append(
        SpatialTransform(
            patch_size_spatial, patch_center_dist_from_border=0, random_crop=False, p_elastic_deform=0,
            p_rotation=1,
            rotation=rotation_for_DA, p_scaling=1, scaling=(0.7, 1.4), p_synchronize_scaling_across_axes=1
        )
    )
    if do_dummy_2d_data_aug:
        transforms.append(Convert2DTo3DTransform())

    transforms.append(
        GaussianNoiseTransform(
            noise_variance=(0, 0.1),
            p_per_channel=1,
            synchronize_channels=True
        )
    )
    transforms.append(
        GaussianBlurTransform(
            blur_sigma=(0.5, 1.),
            synchronize_channels=False,
            synchronize_axes=False,
            p_per_channel=1, benchmark=True
        ))
    transforms.append(
        MultiplicativeBrightnessTransform(
            multiplier_range=BGContrast((0.75, 1.25)),
            synchronize_channels=False,
            p_per_channel=1
        ))
    transforms.append(
        ContrastTransform(
            contrast_range=BGContrast((0.75, 1.25)),
            preserve_range=True,
            synchronize_channels=False,
            p_per_channel=1
        ))
    transforms.append(
        SimulateLowResolutionTransform(
            scale=(0.5, 1),
            synchronize_channels=False,
            synchronize_axes=True,
            ignore_axes=ignore_axes,
            allowed_channels=None,
            p_per_channel=1
        ))
    # gamma with inversion, then without (matches the two GammaTransforms of the batchgenerators pipeline)
    transforms.append(
        GammaTransform(
            gamma=BGContrast((0.7, 1.5)),
            p_invert_image=1,
            synchronize_channels=False,
            p_per_channel=1,
            p_retain_stats=1
        ))
    transforms.append(
        GammaTransform(
            gamma=BGContrast((0.7, 1.5)),
            p_invert_image=0,
            synchronize_channels=False,
            p_per_channel=1,
            p_retain_stats=1
        ))
    transforms.append(
        MirrorTransform(
            allowed_axes=(0, 1, 2)
        )
    )

    transforms.append(MaskImageTransform(
        apply_to_channels=[0, 1, 2, 3],
        channel_idx_in_seg=0,
        set_outside_to=0,
    ))

    transforms.append(
        RemoveLabelTansform(-1, 0)
    )

    transforms.append(
        ConvertSegmentationToRegionsTransform(
            regions=regions,
            channel_in_seg=0
        )
    )

    transforms.append(DownsampleSegForDSTransform(ds_scales=deep_supervision_scales))

    # one list of per-iteration runtimes for every transform
    compute_times = [[] for _ in transforms]

    with torch.no_grad():
        torch.set_num_threads(1)
        for iteration in range(50):
            print(iteration)
            # batchgeneratorsv2 works on single samples without a batch axis: (c, x, y, z)
            data_dict = {'image': torch.rand((4, 128, 128, 128)),
                         'segmentation': torch.round(4.5 * torch.rand((1, 128, 128, 128)) - 1,
                                                     decimals=0).to(torch.int8)}
            for i, t in enumerate(transforms):
                st = perf_counter()
                data_dict = t(**data_dict)
                compute_times[i].append(perf_counter() - st)

    for t, ct in zip(transforms, compute_times):
        # report the wrapped transform's name for RandomTransform wrappers
        print(t.__class__.__name__ if not isinstance(t, RandomTransform) else t.transform.__class__.__name__,
              np.mean(ct))
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from time import time
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
def unique_torch(tensor):
    """Return the sorted unique values of ``tensor`` (via ``Tensor.unique``, which defaults to sorted output)."""
    return tensor.unique()
|
|
8
|
+
|
|
9
|
+
def unique_npy(tensor):
    """Return the sorted unique values of a CPU ``tensor`` as a numpy array, computed with ``np.unique``."""
    as_array = tensor.numpy()
    return np.unique(as_array)
|
|
11
|
+
|
|
12
|
+
def unique_pandas(tensor):
    """
    Return the sorted unique values of a CPU ``tensor`` as a numpy array.

    ``pd.unique`` is hash based and returns values in order of appearance, so the result is sorted afterwards to
    match the other unique_* variants.
    """
    # the result must be returned; previously it was computed and discarded, so the function returned None
    return np.sort(pd.unique(tensor.numpy().ravel()))
|
|
14
|
+
|
|
15
|
+
def unique_bincount(tensor):
    """
    Return the unique values of a non-negative integer ``tensor``.

    Occurrences of every value 0..max are counted with ``torch.bincount``, and the values whose count is non-zero
    are returned in ascending order.
    """
    counts = torch.bincount(tensor.ravel())
    present = counts > 0
    return torch.nonzero(present, as_tuple=True)[0]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
if __name__ == '__main__':
    # Compare several ways of finding the unique labels of a random uint8 segmentation. Every method is timed on
    # freshly generated volumes and the median runtime is printed. perf_counter is a monotonic, high-resolution
    # clock and therefore better suited for timing than time().
    from time import perf_counter

    torch.set_num_threads(1)
    shape = (64, 64, 64)
    max_label = 20  # segmentation values are rounded from [0, max_label]
    num_repeats = 10

    methods = (
        ('unique_torch', unique_torch),
        ('unique_npy', unique_npy),
        ('unique_pandas', unique_pandas),
        ('unique_bincount', unique_bincount),
    )
    for method_name, method in methods:
        times = []
        for _ in range(num_repeats):
            seg = torch.round(torch.rand(shape) * max_label, decimals=0).to(torch.uint8)
            st = perf_counter()
            method(seg)
            times.append(perf_counter() - st)
        print(method_name, np.median(times))
|
|
55
|
+
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Union, Tuple, Callable
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# A scalar that is either fixed or randomly drawn. It can be one of:
# - a constant int/float;
# - a (low, high) pair, sampled uniformly;
# - a callable returning a number, which receives any extra args/kwargs given to sample_scalar.
RandomScalar = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]]


def sample_scalar(scalar_type: RandomScalar, *args, **kwargs):
    """
    Draw a concrete number from a RandomScalar specification.

    :param scalar_type: what to sample from. Supported kinds:
        - int/float (numpy integer/floating scalars included) -> returned unchanged.
        - list/tuple (low, high) -> ``low`` if low == high, else ``np.random.uniform(low, high)``.
        - callable -> ``scalar_type(*args, **kwargs)`` is returned.
    :param args: forwarded to ``scalar_type`` if it is callable, ignored otherwise.
    :param kwargs: forwarded to ``scalar_type`` if it is callable, ignored otherwise.
    :raises AssertionError: if a list/tuple does not have exactly two entries or its first entry is larger than
        the second.
    :raises RuntimeError: if ``scalar_type`` is none of the supported kinds.
    """
    # numpy scalars (e.g. np.float32, np.int64) are not subclasses of float/int but are plain numbers as well
    if isinstance(scalar_type, (int, float, np.integer, np.floating)):
        return scalar_type
    elif isinstance(scalar_type, (list, tuple)):
        assert len(scalar_type) == 2, 'if list is provided, its length must be 2'
        assert scalar_type[0] <= scalar_type[1], 'if list is provided, first entry must be smaller or equal than second entry, ' \
                                                 'otherwise we cannot sample using np.random.uniform'
        if scalar_type[0] == scalar_type[1]:
            return scalar_type[0]
        return np.random.uniform(*scalar_type)
    elif callable(scalar_type):
        return scalar_type(*args, **kwargs)
    else:
        # interpolate the offending type into the message (it used to be passed as a separate exception argument,
        # leaving a literal '%s' in the message)
        raise RuntimeError('Unknown type: %s. Expected: int, float, list, tuple, callable' % type(scalar_type))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
if __name__ == '__main__':
    # Smoke test: exercise every supported RandomScalar kind once, in this order:
    # constant, (low, high) range, callable without arguments, callable with forwarded arguments.
    demo_calls = (
        (0.5,),
        ((0, 1),),
        (lambda: np.random.uniform(-1, 2),),
        (lambda x, y: np.random.uniform(x, y), 0.5, 2),
    )
    for call_args in demo_calls:
        sample_scalar(*call_args)
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BasicTransform(abc.ABC):
    """
    Base class of all transforms.

    Transforms are applied to each sample individually. The dataloader is responsible for collating, or we might
    consider a CollateTransform.

    We expect (C, X, Y) or (C, X, Y, Z) shaped inputs for image and seg (yes, seg can have more than one channel).

    No idea yet what keypoint and bbox will look like, this is Michael's turf.

    ``__call__`` samples the random parameters once via ``get_parameters`` and hands the same parameters to every
    ``_apply_to_*`` hook, so all entries of a sample are transformed consistently. The default hooks for image,
    segmentation, keypoints and bbox return None, so a subclass must override the hook of every entry it expects
    to receive. Otherwise that entry is replaced by None.
    """
    def __init__(self):
        pass

    def __call__(self, **data_dict) -> dict:
        params = self.get_parameters(**data_dict)
        return self.apply(data_dict, **params)

    def apply(self, data_dict, **params):
        """
        Transform every present, non-None entry among 'image', 'regression_target', 'segmentation', 'keypoints'
        and 'bbox' with its matching hook. data_dict is updated in place and returned.
        """
        if data_dict.get('image') is not None:
            data_dict['image'] = self._apply_to_image(data_dict['image'], **params)

        if data_dict.get('regression_target') is not None:
            # regression targets are dispatched to their own hook; its default falls back to the segmentation hook
            data_dict['regression_target'] = self._apply_to_regr_target(data_dict['regression_target'], **params)

        if data_dict.get('segmentation') is not None:
            data_dict['segmentation'] = self._apply_to_segmentation(data_dict['segmentation'], **params)

        if data_dict.get('keypoints') is not None:
            data_dict['keypoints'] = self._apply_to_keypoints(data_dict['keypoints'], **params)

        if data_dict.get('bbox') is not None:
            data_dict['bbox'] = self._apply_to_bbox(data_dict['bbox'], **params)

        return data_dict

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        pass

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        """Transform the regression target. The default treats it like a segmentation; override for custom handling."""
        return self._apply_to_segmentation(regression_target, **params)

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        pass

    def _apply_to_keypoints(self, keypoints, **params):
        pass

    def _apply_to_bbox(self, bbox, **params):
        pass

    def get_parameters(self, **data_dict) -> dict:
        """Sample the parameters for one call. They are passed as keyword arguments to every _apply_to_* hook."""
        return {}

    def __repr__(self):
        ret_str = str(type(self).__name__) + "( " + ", ".join(
            [key + " = " + repr(val) for key, val in self.__dict__.items()]) + " )"
        return ret_str
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ImageOnlyTransform(BasicTransform):
    """Transform that only modifies data_dict['image']. Every other entry is passed through untouched."""
    def apply(self, data_dict: dict, **params) -> dict:
        image = data_dict.get('image')
        if image is None:
            return data_dict
        data_dict['image'] = self._apply_to_image(image, **params)
        return data_dict
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class SegOnlyTransform(BasicTransform):
    """Transform that only modifies data_dict['segmentation']. Every other entry is passed through untouched."""
    def apply(self, data_dict: dict, **params) -> dict:
        seg = data_dict.get('segmentation')
        if seg is None:
            return data_dict
        data_dict['segmentation'] = self._apply_to_segmentation(seg, **params)
        return data_dict
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
if __name__ == '__main__':
    # nothing to run: this module only defines base classes
    ...
|
|
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MultiplicativeBrightnessTransform(ImageOnlyTransform):
    """
    Multiply the intensities of randomly selected image channels by a sampled factor.

    :param multiplier_range: RandomScalar the factor is drawn from. It is passed to sample_scalar together with
        image=<the image> and channel=<channel index, or None when channels are synchronized>.
    :param synchronize_channels: if True, a single factor is shared by all selected channels. Otherwise every
        selected channel gets its own factor.
    :param p_per_channel: probability with which each channel is selected.
    """
    def __init__(self, multiplier_range: RandomScalar, synchronize_channels: bool, p_per_channel: float = 1):
        super().__init__()
        self.multiplier_range = multiplier_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict['image']
        num_channels = image.shape[0]
        selected = torch.where(torch.rand(num_channels) < self.p_per_channel)[0]
        if self.synchronize_channels:
            shared_factor = sample_scalar(self.multiplier_range, image=image, channel=None)
            factors = [shared_factor] * len(selected)
        else:
            factors = [sample_scalar(self.multiplier_range, image=image, channel=c) for c in selected]
        return {
            'apply_to_channel': selected,
            'multipliers': torch.Tensor(factors)
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        channels = params['apply_to_channel']
        if len(channels) == 0:
            return img
        # The vectorized form is a lot slower despite being array notation, so loop over channels instead:
        # img[channels] *= multipliers.view(-1, *[1] * (img.ndim - 1))
        for channel, factor in zip(channels, params['multipliers']):
            img[channel] *= factor
        return img
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
if __name__ == '__main__':
    from time import time
    import numpy as np
    import os

    # single-threaded execution for a fair comparison of both implementations
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _benchmark(transform, make_sample, repeats):
        """Call transform(**make_sample()) `repeats` times and return the runtime of every call."""
        durations = []
        for _ in range(repeats):
            sample = make_sample()
            start = time()
            transform(**sample)
            durations.append(time() - start)
        return durations

    torch_transform = MultiplicativeBrightnessTransform((0.5, 2.), False, 1)
    times_torch = _benchmark(torch_transform, lambda: {'image': torch.ones((2, 128, 192, 64))}, 1000)
    print('torch', np.mean(times_torch))

    from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform

    bg_transform = BrightnessMultiplicativeTransform((0.5, 2), True, p_per_sample=1)
    times_bg = _benchmark(bg_transform, lambda: {'data': np.ones((1, 2, 128, 192, 64))}, 1000)
    print('bg', np.mean(times_bg))
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BGContrast:
    """
    Sampler for batchgenerators-style contrast/multiplier factors.

    If ``contrast_range`` is a callable, it is called without arguments. If it is a (low, high) pair, a fair coin
    decides whether the factor is drawn from the part of the range below 1 or the part at/above 1. Factors < 1
    (reduce contrast) and > 1 (increase contrast) are therefore equally likely, regardless of how asymmetric the
    range is.
    """
    def __init__(self, contrast_range):
        self.contrast_range = contrast_range

    def sample_contrast(self, *args, **kwargs):
        """Return one factor. Extra args/kwargs (e.g. image=, channel= from sample_scalar) are ignored."""
        if callable(self.contrast_range):
            factor = self.contrast_range()
        else:
            low, high = self.contrast_range
            # the coin is always flipped first so the random stream matches the batchgenerators implementation
            use_lower_part = np.random.random() < 0.5 and low < 1
            if use_lower_part or high < 1:
                # clip the sub-range at `high`: when high < 1 the range has no part >= 1 and values must not
                # exceed `high`
                factor = np.random.uniform(low, min(high, 1))
            else:
                factor = np.random.uniform(max(low, 1), high)
        return factor

    def __call__(self, *args, **kwargs):
        return self.sample_contrast(*args, **kwargs)

    def __repr__(self):
        return self.__class__.__name__ + f"(contrast_range={self.contrast_range})"
|
|
27
|
+
|
|
28
|
+
class ContrastTransform(ImageOnlyTransform):
    """
    Scale the deviation from the channel mean for randomly selected image channels.

    Each selected channel is mapped as x -> (x - mean) * factor + mean.

    :param contrast_range: RandomScalar the factor is drawn from. It is passed to sample_scalar together with
        image=<the image> and channel=<channel index, or None when channels are synchronized>.
    :param preserve_range: if True, the result is clipped to the channel's [min, max] from before the transform.
    :param synchronize_channels: if True, one factor is shared by all selected channels.
    :param p_per_channel: probability with which each channel is selected.
    """
    def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
        super().__init__()
        self.contrast_range = contrast_range
        self.preserve_range = preserve_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict['image']
        selected = torch.where(torch.rand(image.shape[0]) < self.p_per_channel)[0]
        if self.synchronize_channels:
            factors = [sample_scalar(self.contrast_range, image=image, channel=None)] * len(selected)
        else:
            factors = [sample_scalar(self.contrast_range, image=image, channel=c) for c in selected]
        return {
            'apply_to_channel': selected,
            'multipliers': torch.Tensor(factors)
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        channels = params['apply_to_channel']
        if len(channels) == 0:
            return img
        factors = params['multipliers']
        # array notation is not faster here, so a plain loop over channels is used
        for idx, c in enumerate(channels):
            channel_mean = img[c].mean()
            if self.preserve_range:
                lower, upper = img[c].min(), img[c].max()

            # three in-place steps instead of one expression avoid reallocating memory
            img[c] -= channel_mean
            img[c] *= factors[idx]
            img[c] += channel_mean

            if self.preserve_range:
                img[c].clamp_(lower, upper)

        return img
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
if __name__ == '__main__':
    from time import time
    import os

    # single-threaded execution for a fair comparison of both implementations
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _benchmark(transform, make_sample, repeats):
        """Call transform(**make_sample()) `repeats` times and return the runtime of every call."""
        durations = []
        for _ in range(repeats):
            sample = make_sample()
            start = time()
            transform(**sample)
            durations.append(time() - start)
        return durations

    torch_transform = ContrastTransform(BGContrast((0.75, 1.25)).sample_contrast, True, False, p_per_channel=1)
    times_torch = _benchmark(torch_transform, lambda: {'image': torch.ones((2, 128, 192, 64))}, 100)
    print('torch', np.mean(times_torch))

    from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform

    bg_transform = ContrastAugmentationTransform((0.75, 1.25), preserve_range=True, per_channel=True, p_per_channel=1)
    times_bg = _benchmark(bg_transform, lambda: {'data': np.ones((1, 2, 128, 192, 64))}, 100)
    print('bg', np.mean(times_bg))
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from typing import Callable, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
|
|
6
|
+
from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GammaTransform(ImageOnlyTransform):
    """
    Gamma-correct randomly selected image channels.

    Each selected channel is min-max normalized to [0, 1], raised to the power gamma and mapped back to its
    original [min, max] range.

    :param gamma: RandomScalar the exponent is drawn from. It is passed to sample_scalar together with
        image=<the image> and channel=<channel index, or None when channels are synchronized>.
    :param p_invert_image: per selected channel, probability of negating it before the gamma correction and
        negating it back afterwards.
    :param synchronize_channels: if True, one gamma is shared by all selected channels.
    :param p_per_channel: probability with which each channel is selected.
    :param p_retain_stats: per selected channel, probability of restoring its mean and standard deviation from
        before the gamma correction.
    """
    def __init__(self, gamma: RandomScalar, p_invert_image: float, synchronize_channels: bool, p_per_channel: float,
                 p_retain_stats: float):
        super().__init__()
        self.gamma = gamma
        self.p_invert_image = p_invert_image
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel
        self.p_retain_stats = p_retain_stats

    def get_parameters(self, **data_dict) -> dict:
        image = data_dict['image']
        selected = torch.where(torch.rand(image.shape[0]) < self.p_per_channel)[0]
        n_selected = len(selected)
        retain_stats = torch.rand(n_selected) < self.p_retain_stats
        invert_image = torch.rand(n_selected) < self.p_invert_image

        if self.synchronize_channels:
            gammas = [sample_scalar(self.gamma, image=image, channel=None)] * n_selected
        else:
            gammas = [sample_scalar(self.gamma, image=image, channel=c) for c in selected]
        return {
            'apply_to_channel': selected,
            'retain_stats': retain_stats,
            'invert_image': invert_image,
            'gamma': torch.Tensor(gammas)
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        per_channel = zip(params['apply_to_channel'], params['retain_stats'], params['invert_image'],
                          params['gamma'])
        for channel, keep_stats, invert, exponent in per_channel:
            if invert:
                img[channel] *= -1
            if keep_stats:
                # separate mean/std calls turned out to be faster than torch.std_mean
                target_mean = torch.mean(img[channel])
                target_std = torch.std(img[channel])
            lowest = torch.min(img[channel])
            span = torch.max(img[channel]) - lowest
            # clamp the span to avoid division by zero on constant channels
            normalized = (img[channel] - lowest) / torch.clamp(span, min=1e-7)
            img[channel] = torch.pow(normalized, exponent) * span + lowest
            if keep_stats:
                current_mean = torch.mean(img[channel])
                current_std = torch.std(img[channel])
                img[channel] -= current_mean
                img[channel] *= (target_std / torch.clamp(current_std, min=1e-7))
                img[channel] += target_mean

            if invert:
                img[channel] *= -1
        return img
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
if __name__ == '__main__':
    from time import time
    import numpy as np
    import os

    # single-threaded execution for a fair comparison of both implementations
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    def _benchmark(transform, make_sample, repeats):
        """Call transform(**make_sample()) `repeats` times and return the runtime of every call."""
        durations = []
        for _ in range(repeats):
            sample = make_sample()
            start = time()
            transform(**sample)
            durations.append(time() - start)
        return durations

    torch_transform = GammaTransform((0.7, 1.5), 0, False, 1, 1)
    times_torch = _benchmark(torch_transform, lambda: {'image': torch.ones((2, 128, 192, 64))}, 100)
    print('torch', np.mean(times_torch))

    from batchgenerators.transforms.color_transforms import GammaTransform as BGGamma

    bg_transform = BGGamma((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1)
    times_bg = _benchmark(bg_transform, lambda: {'data': np.ones((1, 2, 128, 192, 64))}, 100)
    print('bg', np.mean(times_bg))
|