batchgeneratorsv2 0.1__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/PKG-INFO +1 -1
  2. batchgeneratorsv2-0.1.1/batchgeneratorsv2/benchmarks/__init__.py +0 -0
  3. batchgeneratorsv2-0.1.1/batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  4. batchgeneratorsv2-0.1.1/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  5. batchgeneratorsv2-0.1.1/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  6. batchgeneratorsv2-0.1.1/batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  7. batchgeneratorsv2-0.1.1/batchgeneratorsv2/dataloading/__init__.py +0 -0
  8. batchgeneratorsv2-0.1.1/batchgeneratorsv2/helpers/__init__.py +0 -0
  9. batchgeneratorsv2-0.1.1/batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/base/basic_transform.py +72 -0
  13. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/intensity/brightness.py +63 -0
  15. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/intensity/contrast.py +93 -0
  16. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/intensity/gamma.py +88 -0
  17. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/intensity/gaussian_noise.py +80 -0
  18. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  19. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +193 -0
  20. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  21. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  22. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  23. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  24. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  25. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/spatial/low_resolution.py +88 -0
  26. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  27. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/spatial/spatial.py +365 -0
  28. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  29. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/compose.py +14 -0
  30. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  31. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  32. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/nnunet_masking.py +22 -0
  33. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/pseudo2d.py +81 -0
  34. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/random.py +19 -0
  35. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/remove_label.py +24 -0
  36. batchgeneratorsv2-0.1.1/batchgeneratorsv2/transforms/utils/seg_to_regions.py +23 -0
  37. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/batchgeneratorsv2.egg-info/PKG-INFO +1 -1
  38. batchgeneratorsv2-0.1.1/batchgeneratorsv2.egg-info/SOURCES.txt +45 -0
  39. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/pyproject.toml +4 -3
  40. batchgeneratorsv2-0.1/batchgeneratorsv2.egg-info/SOURCES.txt +0 -10
  41. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/LICENSE +0 -0
  42. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/batchgeneratorsv2/__init__.py +0 -0
  43. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/batchgeneratorsv2.egg-info/dependency_links.txt +0 -0
  44. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/batchgeneratorsv2.egg-info/requires.txt +0 -0
  45. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/batchgeneratorsv2.egg-info/top_level.txt +0 -0
  46. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/readme.md +0 -0
  47. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/setup.cfg +0 -0
  48. {batchgeneratorsv2-0.1 → batchgeneratorsv2-0.1.1}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: batchgeneratorsv2
3
- Version: 0.1
3
+ Version: 0.1.1
4
4
  Summary: Batchgenerators but better
5
5
  Author: Helmholtz Imaging Applied Computer Vision Lab
6
6
  Author-email: Fabian Isensee <f.isensee@dkfz-heidelberg.de>
@@ -0,0 +1,90 @@
1
+ from time import time
2
+
3
+ import numpy as np
4
+ import torch
5
+ from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
6
+ ContrastAugmentationTransform, GammaTransform
7
+ from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
8
+ from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
9
+ from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
10
+ from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, NumpyToTensor
11
+ from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
12
+ DownsampleSegForDSTransform2
13
+ from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
14
+ from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
15
+ ConvertSegmentationToRegionsTransform
16
+ from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
17
+ Convert2DTo3DTransform
18
+
19
if __name__ == '__main__':
    # Per-transform timing of a pipeline assembled from batchgenerators / nnunetv2 transforms.
    # Every transform is timed separately on freshly generated random inputs.
    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2*np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    tr_transforms = []
    if do_dummy_2d_data_aug:
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    # rotation + scaling only; elastic deformation is disabled
    spatial_transform = SpatialTransform(
        patch_size_spatial,
        patch_center_dist_from_border=None,
        do_elastic_deform=False,
        alpha=(0, 0),
        sigma=(0, 0),
        do_rotation=True,
        angle_x=rotation_for_DA,
        angle_y=rotation_for_DA,
        angle_z=rotation_for_DA,
        p_rot_per_axis=1,  # todo experiment with this
        do_scale=True,
        scale=(0.7, 1.4),
        border_mode_data="constant",
        border_cval_data=0,
        order_data=3,
        border_mode_seg="constant",
        border_cval_seg=-1,
        order_seg=1,
        random_crop=False,  # random cropping is part of our dataloaders
        p_el_per_sample=0,
        p_scale_per_sample=1,
        p_rot_per_sample=1,
        independent_scale_for_each_axis=False,  # todo experiment with this
    )
    tr_transforms.append(spatial_transform)

    if do_dummy_2d_data_aug:
        tr_transforms.append(Convert2DTo3DTransform())

    # intensity / resolution augmentations, all applied with probability 1 so each one is always timed
    tr_transforms += [
        GaussianNoiseTransform(p_per_sample=1, p_per_channel=1),
        GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=1, p_per_channel=1),
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=1),
        ContrastAugmentationTransform(p_per_sample=1, p_per_channel=1),
        SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, p_per_channel=1,
                                       order_downsample=0, order_upsample=3, p_per_sample=1,
                                       ignore_axes=ignore_axes),
        GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=1),
        GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1),
    ]

    # mirroring, masking and segmentation post-processing, followed by tensor conversion
    tr_transforms += [
        MirrorTransform((0, 1, 2)),
        MaskTransform([0, 1, 2, 3], mask_idx_in_seg=0, set_outside_to=0),
        RemoveLabelTransform(-1, 0),
        ConvertSegmentationToRegionsTransform(regions, 'seg', 'seg'),
        DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='seg', output_key='seg'),
        NumpyToTensor(['data', 'seg'], 'float'),
    ]

    # one list of wall-clock durations per transform
    compute_times = [[] for _ in tr_transforms]

    torch.set_num_threads(1)
    for iteration in range(50):
        print(iteration)
        data_dict = {
            'data': np.random.uniform(size=(1, 4, 128, 128, 128)),
            'seg': np.round(4.5 * np.random.uniform(size=(1, 1, 128, 128, 128)) - 1, decimals=0).astype(np.int8),
        }
        for idx, transform in enumerate(tr_transforms):
            started = time()
            data_dict = transform(**data_dict)
            compute_times[idx].append(time() - started)

    # report the 20th percentile of the measured durations for each transform
    for transform, durations in zip(tr_transforms, compute_times):
        print(transform.__class__.__name__, np.percentile(durations, 20))
@@ -0,0 +1,138 @@
1
+ from time import time
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
7
+ from batchgeneratorsv2.transforms.intensity.contrast import BGContrast, ContrastTransform
8
+ from batchgeneratorsv2.transforms.intensity.gamma import GammaTransform
9
+ from batchgeneratorsv2.transforms.intensity.gaussian_noise import GaussianNoiseTransform
10
+ from batchgeneratorsv2.transforms.noise.gaussian_blur import GaussianBlurTransform
11
+ from batchgeneratorsv2.transforms.spatial.low_resolution import SimulateLowResolutionTransform
12
+ from batchgeneratorsv2.transforms.spatial.mirroring import MirrorTransform
13
+ from batchgeneratorsv2.transforms.spatial.spatial import SpatialTransform
14
+ from batchgeneratorsv2.transforms.utils.compose import ComposeTransforms
15
+ from batchgeneratorsv2.transforms.utils.deep_supervision_downsampling import DownsampleSegForDSTransform
16
+ from batchgeneratorsv2.transforms.utils.nnunet_masking import MaskImageTransform
17
+ from batchgeneratorsv2.transforms.utils.pseudo2d import Convert2DTo3DTransform, Convert3DTo2DTransform
18
+ from batchgeneratorsv2.transforms.utils.random import RandomTransform
19
+ from batchgeneratorsv2.transforms.utils.remove_label import RemoveLabelTansform
20
+ from batchgeneratorsv2.transforms.utils.seg_to_regions import ConvertSegmentationToRegionsTransform
21
+
22
if __name__ == '__main__':
    # Per-transform timing of the batchgeneratorsv2 pipeline on random torch inputs.
    regions = ((1, 2, 3), (2, 3), (3, ))
    do_dummy_2d_data_aug = False
    patch_size = (128, 128, 128)
    rotation_for_DA = (0, 2*np.pi)
    deep_supervision_scales = ((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))

    transforms = []
    if do_dummy_2d_data_aug:
        ignore_axes = (0,)
        transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    # rotation + scaling only; elastic deformation is disabled
    spatial_transform = SpatialTransform(
        patch_size_spatial,
        patch_center_dist_from_border=0,
        random_crop=False,
        p_elastic_deform=0,
        p_rotation=1,
        rotation=rotation_for_DA,
        p_scaling=1,
        scaling=(0.7, 1.4),
        p_synchronize_scaling_across_axes=1,
    )
    transforms.append(spatial_transform)

    if do_dummy_2d_data_aug:
        transforms.append(Convert2DTo3DTransform())

    # intensity / resolution augmentations, each applied to every channel (p_per_channel=1)
    transforms += [
        GaussianNoiseTransform(noise_variance=(0, 0.1), p_per_channel=1, synchronize_channels=True),
        GaussianBlurTransform(blur_sigma=(0.5, 1.), synchronize_channels=False, synchronize_axes=False,
                              p_per_channel=1, benchmark=True),
        MultiplicativeBrightnessTransform(multiplier_range=BGContrast((0.75, 1.25)),
                                          synchronize_channels=False, p_per_channel=1),
        ContrastTransform(contrast_range=BGContrast((0.75, 1.25)), preserve_range=True,
                          synchronize_channels=False, p_per_channel=1),
        SimulateLowResolutionTransform(scale=(0.5, 1), synchronize_channels=False, synchronize_axes=True,
                                       ignore_axes=ignore_axes, allowed_channels=None, p_per_channel=1),
        GammaTransform(gamma=BGContrast((0.7, 1.5)), p_invert_image=1, synchronize_channels=False,
                       p_per_channel=1, p_retain_stats=1),
        GammaTransform(gamma=BGContrast((0.7, 1.5)), p_invert_image=0, synchronize_channels=False,
                       p_per_channel=1, p_retain_stats=1),
    ]

    # mirroring, masking and segmentation post-processing
    transforms += [
        MirrorTransform(allowed_axes=(0, 1, 2)),
        MaskImageTransform(apply_to_channels=[0, 1, 2, 3], channel_idx_in_seg=0, set_outside_to=0),
        RemoveLabelTansform(-1, 0),
        ConvertSegmentationToRegionsTransform(regions=regions, channel_in_seg=0),
        DownsampleSegForDSTransform(ds_scales=deep_supervision_scales),
    ]

    # one list of wall-clock durations per transform
    compute_times = [[] for _ in transforms]

    with torch.no_grad():
        torch.set_num_threads(1)
        for iteration in range(50):
            print(iteration)
            data_dict = {
                'image': torch.rand((4, 128, 128, 128)),
                'segmentation': torch.round(4.5 * torch.rand((1, 128, 128, 128)) - 1, decimals=0).to(torch.int8),
            }
            for idx, transform in enumerate(transforms):
                started = time()
                data_dict = transform(**data_dict)
                compute_times[idx].append(time() - started)

    # report the 20th percentile per transform; RandomTransform wrappers are reported under the wrapped class
    for transform, durations in zip(transforms, compute_times):
        if isinstance(transform, RandomTransform):
            label = transform.transform.__class__.__name__
        else:
            label = transform.__class__.__name__
        print(label, np.percentile(durations, 20))
@@ -0,0 +1,55 @@
1
+ import torch
2
+ import numpy as np
3
+ from time import time
4
+ import pandas as pd
5
+
6
def unique_torch(tensor):
    """Return the unique values of ``tensor`` as computed by ``torch.unique``."""
    distinct_values = torch.unique(tensor)
    return distinct_values
8
+
9
def unique_npy(tensor):
    """Convert ``tensor`` to a numpy array and return its unique values via ``np.unique``."""
    as_array = tensor.numpy()
    return np.unique(as_array)
11
+
12
def unique_pandas(tensor):
    """Return the sorted unique values of ``tensor`` using ``pandas.unique``.

    The tensor is converted to a flat numpy array first. ``pd.unique`` returns values in order
    of appearance, so the result is sorted to match the other ``unique_*`` helpers.
    """
    # the original computed this value but never returned it, so callers always got None
    return np.sort(pd.unique(tensor.numpy().ravel()))
14
+
15
def unique_bincount(tensor):
    """Return the values occurring in ``tensor``, found as the indices of non-empty ``torch.bincount`` bins."""
    counts = torch.bincount(tensor.ravel())
    occupied = counts > 0
    return torch.where(occupied)[0]
17
+
18
+
19
if __name__ == '__main__':
    # Time each unique-value implementation on random uint8 label volumes and print the median duration.
    torch.set_num_threads(1)
    shape = (64, 64, 64)
    labels = 200

    # (printed label, implementation) pairs, benchmarked in this order
    candidates = (
        ('unique_torch', unique_torch),
        ('unique_npy', unique_npy),
        ('unique_pandas', unique_pandas),
        ('unique_bincount', unique_bincount),
    )

    for label, unique_fn in candidates:
        times = []
        for _ in range(10):
            # a fresh random volume per repetition; generation is excluded from the timing
            seg = torch.round(torch.rand(shape) * 20, decimals=0).to(torch.uint8)
            st = time()
            unique = unique_fn(seg)
            times.append(time() - st)
        print(label, np.median(times))
55
+
@@ -0,0 +1,28 @@
1
+ from typing import Union, Tuple, Callable
2
+ import numpy as np
3
+
4
+
5
# A scalar specification: a fixed number, a (low, high) pair to sample uniformly from,
# or a callable that returns a number.
RandomScalar = Union[int, float, Tuple[float, float], Callable[..., Union[int, float]]]


def sample_scalar(scalar_type: RandomScalar, *args, **kwargs):
    """Resolve a ``RandomScalar`` specification to a concrete value.

    :param scalar_type: one of
        - an ``int`` or ``float``: returned unchanged
        - a list/tuple ``(low, high)``: ``low`` is returned if both entries are equal, otherwise a
          sample from ``np.random.uniform(low, high)``
        - a callable: invoked with ``*args`` and ``**kwargs`` and its result is returned
    :param args: positional arguments forwarded to ``scalar_type`` when it is callable
    :param kwargs: keyword arguments forwarded to ``scalar_type`` when it is callable
    :return: the resolved scalar
    :raises AssertionError: if a list/tuple does not have exactly two entries or is not ordered
    :raises RuntimeError: if ``scalar_type`` is none of the supported kinds
    """
    if isinstance(scalar_type, (int, float)):
        return scalar_type
    elif isinstance(scalar_type, (list, tuple)):
        assert len(scalar_type) == 2, 'if list is provided, its length must be 2'
        assert scalar_type[0] <= scalar_type[1], 'if list is provided, first entry must be smaller or equal than second entry, ' \
                                                 'otherwise we cannot sample using np.random.uniform'
        if scalar_type[0] == scalar_type[1]:
            # degenerate interval: no need to touch the random number generator
            return scalar_type[0]
        return np.random.uniform(*scalar_type)
    elif callable(scalar_type):
        return scalar_type(*args, **kwargs)
    else:
        # exceptions do not interpolate logging-style arguments, so the message is formatted explicitly
        raise RuntimeError(f'Unknown type: {type(scalar_type)}. Expected: int, float, list, tuple, callable')
22
+
23
+
24
if __name__ == '__main__':
    # smoke test: one call per supported kind of scalar specification
    fixed_value = 0.5
    interval = (0, 1)

    def no_arg_sampler():
        return np.random.uniform(-1, 2)

    def two_arg_sampler(x, y):
        return np.random.uniform(x, y)

    sample_scalar(fixed_value)
    sample_scalar(interval)
    sample_scalar(no_arg_sampler)
    sample_scalar(two_arg_sampler, 0.5, 2)
@@ -0,0 +1,72 @@
1
+ import abc
2
+ import torch
3
+
4
+
5
class BasicTransform(abc.ABC):
    """
    Base class for transforms. Transforms are applied to each sample individually. The dataloader is
    responsible for collating, or we might consider a CollateTransform.

    We expect (C, X, Y) or (C, X, Y, Z) shaped inputs for image and seg (yes, seg can have more than one
    channel).

    The layout of keypoints and bounding boxes is not defined yet.

    Calling a transform first draws its parameters via ``get_parameters`` and then hands them to ``apply``,
    which dispatches every entry of the data dict that is present (not None) to its ``_apply_to_*`` hook.
    Subclasses override ``get_parameters`` and the hooks they support. The hooks of this base class return
    None, so a hook that is not overridden replaces its entry with None.
    """
    def __init__(self):
        pass

    def __call__(self, **data_dict) -> dict:
        """Sample the parameters for this call and apply the transform to ``data_dict``."""
        params = self.get_parameters(**data_dict)
        return self.apply(data_dict, **params)

    def apply(self, data_dict, **params):
        """
        Run every entry of ``data_dict`` that is not None through its hook, using the same ``params``.

        :param data_dict: dict that may contain 'image', 'regression_target', 'segmentation', 'keypoints', 'bbox'
        :param params: parameters returned by ``get_parameters``
        :return: the same dict, modified in place
        """
        if data_dict.get('image') is not None:
            data_dict['image'] = self._apply_to_image(data_dict['image'], **params)

        if data_dict.get('regression_target') is not None:
            # regression targets have their own hook; they used to be routed through
            # _apply_to_segmentation, which left _apply_to_regr_target unreachable
            data_dict['regression_target'] = self._apply_to_regr_target(data_dict['regression_target'], **params)

        if data_dict.get('segmentation') is not None:
            data_dict['segmentation'] = self._apply_to_segmentation(data_dict['segmentation'], **params)

        if data_dict.get('keypoints') is not None:
            data_dict['keypoints'] = self._apply_to_keypoints(data_dict['keypoints'], **params)

        if data_dict.get('bbox') is not None:
            data_dict['bbox'] = self._apply_to_bbox(data_dict['bbox'], **params)

        return data_dict

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Hook for the 'image' entry. To be overridden by subclasses."""
        pass

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        """Hook for the 'regression_target' entry. To be overridden by subclasses."""
        pass

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        """Hook for the 'segmentation' entry. To be overridden by subclasses."""
        pass

    def _apply_to_keypoints(self, keypoints, **params):
        """Hook for the 'keypoints' entry. To be overridden by subclasses."""
        pass

    def _apply_to_bbox(self, bbox, **params):
        """Hook for the 'bbox' entry. To be overridden by subclasses."""
        pass

    def get_parameters(self, **data_dict) -> dict:
        """Return the parameters for one call. The base implementation has none."""
        return {}
55
+
56
+
57
class ImageOnlyTransform(BasicTransform):
    """Transform that only touches the 'image' entry of the data dict; all other entries pass through."""

    def apply(self, data_dict: dict, **params) -> dict:
        """Apply ``_apply_to_image`` to ``data_dict['image']`` if it is present and not None."""
        image = data_dict.get('image')
        if image is None:
            # nothing to transform
            return data_dict
        data_dict['image'] = self._apply_to_image(image, **params)
        return data_dict
62
+
63
+
64
class SegOnlyTransform(BasicTransform):
    """Transform that only touches the 'segmentation' entry of the data dict; all other entries pass through."""

    def apply(self, data_dict: dict, **params) -> dict:
        """Apply ``_apply_to_segmentation`` to ``data_dict['segmentation']`` if it is present and not None."""
        segmentation = data_dict.get('segmentation')
        if segmentation is None:
            # nothing to transform
            return data_dict
        data_dict['segmentation'] = self._apply_to_segmentation(segmentation, **params)
        return data_dict
69
+
70
+
71
+ if __name__ == '__main__':
72
+ pass
@@ -0,0 +1,63 @@
1
+ import torch
2
+
3
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
4
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
5
+
6
+
7
class MultiplicativeBrightnessTransform(ImageOnlyTransform):
    """
    Multiplies randomly selected image channels by a sampled factor. The image is modified in place.

    :param multiplier_range: RandomScalar specification of the multiplier (resolved with ``sample_scalar``)
    :param synchronize_channels: if True, a single multiplier is sampled and shared by all selected channels;
        otherwise one multiplier is sampled per selected channel
    :param p_per_channel: probability with which each channel is selected
    """
    def __init__(self, multiplier_range: RandomScalar, synchronize_channels: bool, p_per_channel: float = 1):
        super().__init__()
        self.multiplier_range = multiplier_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        """Select the channels to modify and sample their multipliers."""
        image = data_dict['image']
        num_channels = image.shape[0]
        apply_to_channel = torch.where(torch.rand(num_channels) < self.p_per_channel)[0]

        if self.synchronize_channels:
            # one draw, repeated for every selected channel
            shared_multiplier = sample_scalar(self.multiplier_range, image=image, channel=None)
            multipliers = torch.Tensor([shared_multiplier] * len(apply_to_channel))
        else:
            sampled = [sample_scalar(self.multiplier_range, image=image, channel=c) for c in apply_to_channel]
            multipliers = torch.Tensor(sampled)

        return {
            'apply_to_channel': apply_to_channel,
            'multipliers': multipliers
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Scale every selected channel of ``img`` in place by its multiplier."""
        selected_channels = params['apply_to_channel']
        if len(selected_channels) == 0:
            return img
        # The vectorized form
        #   img[params['apply_to_channel']] *= params['multipliers'].view(-1, *[1]*(img.ndim - 1))
        # was found to be a lot slower than this per-channel loop.
        for channel, multiplier in zip(selected_channels, params['multipliers']):
            img[channel] *= multiplier
        return img
34
+
35
+
36
if __name__ == '__main__':
    # Speed comparison: this implementation vs. the batchgenerators counterpart, single-threaded.
    from time import time
    import numpy as np
    import os

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    mbt = MultiplicativeBrightnessTransform((0.5, 2.), False, 1)

    times_torch = []
    for _ in range(1000):
        # input creation is excluded from the measurement
        data_dict = {'image': torch.ones((2, 128, 192, 64))}
        started = time()
        out = mbt(**data_dict)
        elapsed = time() - started
        times_torch.append(elapsed)
    print('torch', np.mean(times_torch))

    # imported only now so the torch benchmark above also runs without batchgenerators installed
    from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform

    gnt_bg = BrightnessMultiplicativeTransform((0.5, 2), True, p_per_sample=1)
    times_bg = []
    for _ in range(1000):
        data_dict = {'data': np.ones((1, 2, 128, 192, 64))}
        started = time()
        out = gnt_bg(**data_dict)
        elapsed = time() - started
        times_bg.append(elapsed)
    print('bg', np.mean(times_bg))
@@ -0,0 +1,93 @@
1
+ import torch
2
+
3
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
4
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
5
+ import numpy as np
6
+
7
+
8
class BGContrast:
    """
    Callable sampler for a contrast-like factor.

    ``contrast_range`` is either a callable (invoked without arguments) or an indexable ``(low, high)``
    pair. For a pair, with probability 0.5 (and only if ``low < 1``) the factor is drawn uniformly from
    ``[low, 1]``; otherwise it is drawn uniformly from ``[max(low, 1), high]``.
    """
    def __init__(self, contrast_range):
        self.contrast_range = contrast_range

    def sample_contrast(self, *args, **kwargs):
        """Draw one factor. Positional and keyword arguments are accepted but ignored."""
        if callable(self.contrast_range):
            return self.contrast_range()

        # the coin is always flipped first, so the random number stream does not depend on the range
        coin = np.random.random()
        lower = self.contrast_range[0]
        if coin < 0.5 and lower < 1:
            return np.random.uniform(lower, 1)
        return np.random.uniform(max(lower, 1), self.contrast_range[1])

    def __call__(self, *args, **kwargs):
        """Alias for ``sample_contrast`` so instances can be used wherever a callable is expected."""
        return self.sample_contrast(*args, **kwargs)
24
+
25
class ContrastTransform(ImageOnlyTransform):
    """
    Scales the deviation of randomly selected channels from their mean by a sampled factor.
    The image is modified in place.

    :param contrast_range: RandomScalar specification of the factor (resolved with ``sample_scalar``)
    :param preserve_range: if True, the result is clamped to the channel's min/max before the transform
    :param synchronize_channels: if True, a single factor is sampled and shared by all selected channels;
        otherwise one factor is sampled per selected channel
    :param p_per_channel: probability with which each channel is selected
    """
    def __init__(self, contrast_range: RandomScalar, preserve_range: bool, synchronize_channels: bool, p_per_channel: float = 1):
        super().__init__()
        self.contrast_range = contrast_range
        self.preserve_range = preserve_range
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel

    def get_parameters(self, **data_dict) -> dict:
        """Select the channels to modify and sample their factors."""
        image = data_dict['image']
        num_channels = image.shape[0]
        apply_to_channel = torch.where(torch.rand(num_channels) < self.p_per_channel)[0]

        if self.synchronize_channels:
            # one draw, repeated for every selected channel
            shared_factor = sample_scalar(self.contrast_range, image=image, channel=None)
            multipliers = torch.Tensor([shared_factor] * len(apply_to_channel))
        else:
            sampled = [sample_scalar(self.contrast_range, image=image, channel=c) for c in apply_to_channel]
            multipliers = torch.Tensor(sampled)

        return {
            'apply_to_channel': apply_to_channel,
            'multipliers': multipliers
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Apply the sampled factor to each selected channel of ``img`` in place."""
        selected_channels = params['apply_to_channel']
        if len(selected_channels) == 0:
            return img
        # array notation is not faster, so the channels are processed one by one
        for c, factor in zip(selected_channels, params['multipliers']):
            mean = img[c].mean()
            if self.preserve_range:
                # remember the original value range before touching the data
                minm = img[c].min()
                maxm = img[c].max()

            # three separate in-place steps are faster than a single expression because
            # they avoid reallocating memory
            img[c] -= mean
            img[c] *= factor
            img[c] += mean

            if self.preserve_range:
                img[c].clamp_(minm, maxm)

        return img
65
+
66
+
67
if __name__ == '__main__':
    # Speed comparison: this implementation vs. the batchgenerators counterpart, single-threaded.
    from time import time
    import os

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    mbt = ContrastTransform(BGContrast((0.75, 1.25)).sample_contrast, True, False, p_per_channel=1)

    times_torch = []
    for _ in range(100):
        # input creation is excluded from the measurement
        data_dict = {'image': torch.ones((2, 128, 192, 64))}
        started = time()
        out = mbt(**data_dict)
        elapsed = time() - started
        times_torch.append(elapsed)
    print('torch', np.mean(times_torch))

    # imported only now so the torch benchmark above also runs without batchgenerators installed
    from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform

    gnt_bg = ContrastAugmentationTransform((0.75, 1.25), preserve_range=True, per_channel=True, p_per_channel=1)
    times_bg = []
    for _ in range(100):
        data_dict = {'data': np.ones((1, 2, 128, 192, 64))}
        started = time()
        out = gnt_bg(**data_dict)
        elapsed = time() - started
        times_bg.append(elapsed)
    print('bg', np.mean(times_bg))
@@ -0,0 +1,88 @@
1
+ from typing import Callable, Union
2
+
3
+ import torch
4
+
5
+ from batchgeneratorsv2.helpers.scalar_type import RandomScalar, sample_scalar
6
+ from batchgeneratorsv2.transforms.base.basic_transform import ImageOnlyTransform
7
+
8
+
9
class GammaTransform(ImageOnlyTransform):
    """
    Applies a gamma curve to randomly selected channels. The image is modified in place.

    For each selected channel the values are shifted and scaled by the channel's min and range, raised to
    the sampled gamma, and mapped back. Optionally the channel is negated before and after (inversion), and
    optionally its mean and standard deviation are restored afterwards.

    :param gamma: RandomScalar specification of the exponent (resolved with ``sample_scalar``)
    :param p_invert_image: per selected channel, probability of negating the channel around the gamma step
    :param synchronize_channels: if True, a single gamma is sampled and shared by all selected channels;
        otherwise one gamma is sampled per selected channel
    :param p_per_channel: probability with which each channel is selected
    :param p_retain_stats: per selected channel, probability of restoring mean and std after the gamma step
    """
    def __init__(self, gamma: RandomScalar, p_invert_image: float, synchronize_channels: bool, p_per_channel: float,
                 p_retain_stats: float):
        super().__init__()
        self.gamma = gamma
        self.p_invert_image = p_invert_image
        self.synchronize_channels = synchronize_channels
        self.p_per_channel = p_per_channel
        self.p_retain_stats = p_retain_stats

    def get_parameters(self, **data_dict) -> dict:
        """Select the channels to modify and draw their retain-stats flag, inversion flag and gamma."""
        image = data_dict['image']
        apply_to_channel = torch.where(torch.rand(image.shape[0]) < self.p_per_channel)[0]
        num_selected = len(apply_to_channel)
        retain_stats = torch.rand(num_selected) < self.p_retain_stats
        invert_image = torch.rand(num_selected) < self.p_invert_image

        if self.synchronize_channels:
            # one draw, repeated for every selected channel
            shared_gamma = sample_scalar(self.gamma, image=image, channel=None)
            gamma = torch.Tensor([shared_gamma] * num_selected)
        else:
            sampled = [sample_scalar(self.gamma, image=image, channel=c) for c in apply_to_channel]
            gamma = torch.Tensor(sampled)

        return {
            'apply_to_channel': apply_to_channel,
            'retain_stats': retain_stats,
            'invert_image': invert_image,
            'gamma': gamma
        }

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        """Apply the gamma curve to each selected channel of ``img`` in place."""
        per_channel_settings = zip(params['apply_to_channel'], params['retain_stats'],
                                   params['invert_image'], params['gamma'])
        for channel, keep_stats, invert, exponent in per_channel_settings:
            if invert:
                img[channel] *= -1

            if keep_stats:
                # torch.std_mean was measured to be slower than computing mean and std separately
                mean = torch.mean(img[channel])
                std = torch.std(img[channel])

            minm = torch.min(img[channel])
            rnge = torch.max(img[channel]) - minm
            # the divisor is clamped so a constant channel does not cause a division by zero
            normalized = (img[channel] - minm) / torch.clamp(rnge, min=1e-7)
            img[channel] = torch.pow(normalized, exponent) * rnge + minm

            if keep_stats:
                # restore the mean / std recorded before the gamma step
                mn_here = torch.mean(img[channel])
                std_here = torch.std(img[channel])
                img[channel] -= mn_here
                img[channel] *= (std / torch.clamp(std_here, min=1e-7))
                img[channel] += mean

            if invert:
                # undo the negation applied at the start
                img[channel] *= -1
        return img
59
+
60
+
61
if __name__ == '__main__':
    # Speed comparison: this implementation vs. the batchgenerators counterpart, single-threaded.
    from time import time
    import numpy as np
    import os

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    mbt = GammaTransform((0.7, 1.5), 0, False, 1, 1)

    times_torch = []
    for _ in range(100):
        # input creation is excluded from the measurement
        data_dict = {'image': torch.ones((2, 128, 192, 64))}
        started = time()
        out = mbt(**data_dict)
        elapsed = time() - started
        times_torch.append(elapsed)
    print('torch', np.mean(times_torch))

    # imported only now so the torch benchmark above also runs without batchgenerators installed
    from batchgenerators.transforms.color_transforms import GammaTransform as BGGamma

    gnt_bg = BGGamma((0.7, 1.5), False, True, retain_stats=True, p_per_sample=1)
    times_bg = []
    for _ in range(100):
        data_dict = {'data': np.ones((1, 2, 128, 192, 64))}
        started = time()
        out = gnt_bg(**data_dict)
        elapsed = time() - started
        times_bg.append(elapsed)
    print('bg', np.mean(times_bg))