batchgeneratorsv2 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batchgeneratorsv2/benchmarks/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
- batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
- batchgeneratorsv2/benchmarks/unique_values.py +55 -0
- batchgeneratorsv2/dataloading/__init__.py +0 -0
- batchgeneratorsv2/helpers/__init__.py +0 -0
- batchgeneratorsv2/helpers/fft_conv.py +149 -0
- batchgeneratorsv2/helpers/scalar_type.py +28 -0
- batchgeneratorsv2/transforms/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/__init__.py +0 -0
- batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
- batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
- batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
- batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
- batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
- batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
- batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
- batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
- batchgeneratorsv2/transforms/local/__init__.py +0 -0
- batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
- batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
- batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
- batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
- batchgeneratorsv2/transforms/local/local_transform.py +86 -0
- batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
- batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
- batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
- batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
- batchgeneratorsv2/transforms/noise/__init__.py +0 -0
- batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
- batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
- batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
- batchgeneratorsv2/transforms/noise/rician.py +61 -0
- batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
- batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
- batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
- batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
- batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
- batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
- batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
- batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
- batchgeneratorsv2/transforms/utils/__init__.py +0 -0
- batchgeneratorsv2/transforms/utils/compose.py +89 -0
- batchgeneratorsv2/transforms/utils/cropping.py +73 -0
- batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
- batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
- batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
- batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
- batchgeneratorsv2/transforms/utils/random.py +46 -0
- batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
- batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
- batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
- batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
- batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
- batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
- batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from time import perf_counter
|
|
5
|
+
from typing import List, Optional, Dict, Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ComposeTransforms(BasicTransform):
    """Apply a sequence of transforms one after another.

    Each transform is called with the current data dict unpacked as keyword
    arguments; whatever it returns becomes the input of the next transform.
    """

    def __init__(self, transforms: List[BasicTransform]):
        super().__init__()
        # Stored by reference and iterated in order on every call.
        self.transforms = transforms

    def apply(self, data_dict, **params):
        """Feed ``data_dict`` through every transform in ``self.transforms``.

        ``params`` is accepted for interface compatibility and is not used.
        """
        current = data_dict
        for transform in self.transforms:
            current = transform(**current)
        return current
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class _TimingStats:
    """Running sum of wall-clock durations (in seconds) plus the number of samples."""

    total_s: float = 0.0  # accumulated duration, seconds
    n: int = 0  # number of recorded durations

    def add(self, dt: float) -> None:
        """Record one measured duration ``dt`` (seconds)."""
        self.total_s = self.total_s + dt
        self.n = self.n + 1

    @property
    def mean_s(self) -> float:
        """Mean recorded duration in seconds; 0.0 when nothing has been recorded."""
        if self.n <= 0:
            return 0.0
        return self.total_s / self.n
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TimedComposeTransforms(BasicTransform):
    """
    ComposeTransforms variant that measures per-transform wall-clock time and prints
    average times after every `print_every` calls to `apply`.

    Notes:
    - Measures wall clock time via perf_counter.
    - For CPU-only pipelines, this is representative. For GPU, you'd need synchronization.
    - Averages are cumulative since construction or the last `reset_timings()` call.
    - `p_write` is the probability that a due report is actually printed. NumPy's global
      RNG is only consulted when 0 < p_write < 1, so with the default p_write=1.0 this
      wrapper does not change the random stream seen by the wrapped transforms.
    """

    def __init__(self, transforms: List[BasicTransform], print_every: int = 100, name: Optional[str] = None, p_write: float = 1.0):
        super().__init__()
        # Validate AFTER int(): e.g. print_every=0.5 passes a "<= 0" check but truncates to 0,
        # which would make `self._iter % self.print_every` raise ZeroDivisionError later.
        print_every_int = int(print_every)
        if print_every_int <= 0:
            raise ValueError("print_every must be >= 1")
        self.transforms = transforms
        self.print_every = print_every_int
        self.name = name or self.__class__.__name__

        self._iter = 0  # completed apply() calls since construction / last reset
        self.p_write = p_write
        # One accumulator per pipeline position (keyed by index, so repeated instances stay separate).
        self._stats: Dict[int, _TimingStats] = {i: _TimingStats() for i in range(len(transforms))}

    def reset_timings(self) -> None:
        """Reset accumulated timing statistics and iteration counter."""
        self._iter = 0
        for s in self._stats.values():
            s.total_s = 0.0
            s.n = 0

    def _stats_for(self, i: int) -> _TimingStats:
        """Return the accumulator for pipeline position ``i``, creating it if missing.

        `self.transforms` is stored by reference, so it may grow after construction.
        """
        st = self._stats.get(i)
        if st is None:
            st = self._stats[i] = _TimingStats()
        return st

    def _transform_display_name(self, t: BasicTransform) -> str:
        # Prefer explicit "name" attribute if present, otherwise class name
        return getattr(t, "name", None) or t.__class__.__name__

    def _print_report(self) -> None:
        """Print the mean time per transform, in pipeline order, in milliseconds."""
        # The stats are cumulative (not a sliding window), hence no "last" in the header.
        lines = [f"[{self.name}] Average transform times over {self._iter} iterations:"]
        for i, t in enumerate(self.transforms):
            st = self._stats_for(i)
            lines.append(f" {i:02d} {self._transform_display_name(t)}: {st.mean_s * 1e3:.3f} ms")
        print("\n".join(lines), flush=True)

    def _should_print(self) -> bool:
        """Decide whether a report is due on this iteration."""
        if self._iter % self.print_every != 0:
            return False
        # np.random.uniform() lies in [0, 1): p_write >= 1 always prints and p_write <= 0 never
        # does. Short-circuit those cases so the global RNG is only consumed when needed.
        if self.p_write >= 1.0:
            return True
        if self.p_write <= 0.0:
            return False
        return np.random.uniform() < self.p_write

    def apply(self, data_dict: Dict[str, Any], **params) -> Dict[str, Any]:
        """Run all transforms in order, timing each call; periodically print a report."""
        for i, t in enumerate(self.transforms):
            t0 = perf_counter()
            data_dict = t(**data_dict)
            self._stats_for(i).add(perf_counter() - t0)

        self._iter += 1
        if self._should_print():
            self._print_report()

        return data_dict
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def crop_tensor(input_tensor, center, crop_size, pad_mode='constant', pad_kwargs=None):
    """
    Crops and pads an input tensor based on the specified center and crop size. Padding can be customized.

    Along each spatial axis the requested window is
    ``[center - crop_size // 2, center - crop_size // 2 + crop_size)``. The part of the window
    that overlaps the tensor is copied and the remainder is filled via
    ``torch.nn.functional.pad``, so the spatial output shape is always ``crop_size``. This
    includes windows that lie completely outside the tensor. In that case only modes that
    can pad an empty tensor (e.g. 'constant') work.

    Parameters:
    - input_tensor (torch.Tensor): The input tensor with shape (c, x, y) or (c, x, y, z).
    - center (tuple): The center coordinates of the crop (x, y) or (x, y, z).
    - crop_size (tuple): The size of the crop (width, height) or (width, height, depth).
    - pad_mode (str): The mode to use for padding (see torch.nn.functional.pad documentation).
    - pad_kwargs (dict, optional): Additional keyword arguments for padding. Defaults to {'value': 0}.

    Returns:
    - torch.Tensor: The cropped and possibly padded tensor, shape (c, *crop_size).
    """
    if pad_kwargs is None:
        pad_kwargs = {'value': 0}

    dim = len(center)  # number of spatial dimensions
    assert len(crop_size) == dim, "Crop size and center must have the same number of dimensions"
    assert input_tensor.ndim - 1 == dim, "Crop size and input_tensor must have the same number of spatial dimensions"

    spatial_shape = input_tensor.shape[-dim:]

    slices = [slice(None)]  # keep all channels
    pad_before = []
    pad_after = []
    for cen, cs, sh in zip(center, crop_size, spatial_shape):
        lo = cen - cs // 2  # requested window start (may be negative)
        hi = lo + cs        # requested window end (may exceed sh)
        # Clip the window to the tensor. Clipping is monotonic, so s <= e always holds;
        # s == e means the window does not overlap the tensor along this axis.
        s = min(max(lo, 0), sh)
        e = min(max(hi, 0), sh)
        slices.append(slice(s, e))
        # Whatever lies left of the tensor is padded before, the rest after. Both amounts are
        # non-negative and sum to cs - (e - s), so the padded extent is exactly cs. (The old
        # formulation could yield negative padding or a wrong size for out-of-range windows.)
        before = min(cs, max(0, -lo))
        pad_before.append(before)
        pad_after.append(cs - (e - s) - before)

    cropped = input_tensor[tuple(slices)]

    # F.pad expects (before, after) pairs starting from the LAST dimension.
    pad_width = [p for b, a in zip(reversed(pad_before), reversed(pad_after)) for p in (b, a)]
    if any(pad_width):
        cropped = torch.nn.functional.pad(cropped, pad_width, mode=pad_mode, **pad_kwargs)

    return cropped


def center_crop(input_tensor, crop_size, pad_mode='constant', pad_kwargs=None):
    """
    Performs a center crop on the input tensor. If the crop extends beyond the borders of the tensor,
    it will be padded according to the specified pad_mode and pad_kwargs.

    The center along each spatial axis is ``size // 2``.

    Parameters:
    - input_tensor (torch.Tensor): The input tensor with shape (c, x, y) or (c, x, y, z).
    - crop_size (tuple): The size of the crop (width, height) or (width, height, depth).
    - pad_mode (str): The mode to use for padding (see torch.nn.functional.pad documentation).
    - pad_kwargs (dict, optional): Additional keyword arguments for padding.

    Returns:
    - torch.Tensor: The center-cropped and possibly padded tensor.
    """
    dim = len(input_tensor.shape) - 1  # number of spatial dimensions (2 or 3)
    spatial_shape = input_tensor.shape[-dim:]
    center = tuple(s // 2 for s in spatial_shape)
    return crop_tensor(input_tensor, center, crop_size, pad_mode, pad_kwargs)
|
|
73
|
+
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Tuple, List, Union
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
|
|
5
|
+
from torch.nn.functional import interpolate
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DownsampleSegForDSTransform(SegOnlyTransform):
    """
    Produce one nearest-neighbor resampled copy of the segmentation per deep-supervision scale.

    ``ds_scales`` holds one entry per output. Each entry is either a scalar factor applied to
    every spatial axis, or an iterable with one factor per spatial axis (tuple, list, or e.g. a
    1-D numpy array). An entry whose factors are all 1 yields the input tensor itself (no copy).
    The result is a list of tensors, in the order of ``ds_scales``.
    """

    def __init__(self, ds_scales: Union[List, Tuple]):
        super().__init__()
        self.ds_scales = ds_scales

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> List[torch.Tensor]:
        n_spatial = segmentation.ndim - 1  # axis 0 is the channel axis
        results = []
        for s in self.ds_scales:
            factors = self._per_axis_factors(s, n_spatial)
            if all(f == 1 for f in factors):
                results.append(segmentation)
                continue
            new_shape = [round(size * f) for size, f in zip(segmentation.shape[1:], factors)]
            # interpolate is not defined for short etc, so resample in float32 and cast back.
            # 'nearest-exact' only copies existing values (exact for integer labels up to 2**24).
            resized = interpolate(segmentation[None].float(), new_shape, mode='nearest-exact')[0]
            results.append(resized.to(segmentation.dtype))
        return results

    @staticmethod
    def _per_axis_factors(s, n_spatial: int) -> list:
        """Expand scale entry ``s`` into a list with one factor per spatial axis."""
        try:
            factors = list(s)
        except TypeError:
            # Python/numpy scalar (or 0-d array): same factor for every spatial axis.
            return [s] * n_spatial
        assert len(factors) == n_spatial, (
            f"scale {s} has {len(factors)} entries but the segmentation has {n_spatial} spatial axes")
        return factors
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
if __name__ == '__main__':
    # Micro-benchmark: this torch implementation vs. nnU-Net's DownsampleSegForDSTransform2
    # (the second half requires nnunetv2 to be installed).
    import os
    from time import time

    import numpy as np

    # Restrict to a single thread (presumably for a like-for-like single-core comparison).
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    ours = DownsampleSegForDSTransform((1, 0.5, 0.25))

    durations_ours = []
    for _ in range(1):
        sample = {'segmentation': torch.round(5 * torch.rand((2, 128, 192, 64)), decimals=0).to(torch.uint8)}
        tic = time()
        out = ours(**sample)
        durations_ours.append(time() - tic)
    print('torch', np.mean(durations_ours))

    from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
        DownsampleSegForDSTransform2

    theirs = DownsampleSegForDSTransform2((1, 0.5, 0.25), order=0)
    durations_theirs = []
    for _ in range(1):
        sample = {'seg': np.round(5 * np.random.uniform(size=(1, 2, 128, 192, 64)), decimals=0).astype(np.uint8)}
        tic = time()
        out = theirs(**sample)
        durations_theirs.append(time() - tic)
    print('bg', np.mean(durations_theirs))
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import Tuple, Union
|
|
2
|
+
import torch
|
|
3
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MoveChannelsTransform(BasicTransform):
    """
    Move selected channels from ``data_dict[source_key]`` to ``data_dict[target_key]``.

    The moved channels are appended along axis 0 after the target's existing channels, or
    become the target when ``target_key`` is absent. They are taken in ascending source-index
    order (boolean-mask selection), regardless of the order in ``channel_ids``. If no channels
    remain in the source, ``source_key`` is deleted from the dict.
    """

    def __init__(self, channel_ids: Union[int, Tuple[int, ...]], source_key: str, target_key: str):
        super().__init__()
        self.channel_ids: Tuple[int, ...] = (channel_ids,) if isinstance(channel_ids, int) else tuple(channel_ids)
        self.source_key = source_key
        self.target_key = target_key

    def get_parameters(self, **data_dict) -> dict:
        # Deterministic transform: nothing to sample.
        return {}

    def apply(self, data_dict, **params):
        source = data_dict[self.source_key]
        assert source.ndim in (3, 4), f"Expected (C,X,Y) or (C,X,Y,Z), got {source.shape}"

        existing = None
        if self.target_key in data_dict:
            existing = data_dict[self.target_key]
            assert source.ndim == existing.ndim, "source and target key must have the same number of dimensions"
            assert source.shape[1:] == existing.shape[1:], (
                f"spatial dimensions must match. Got source: {source.shape} and target: {existing.shape}"
            )

        n_channels = source.shape[0]
        selected = torch.zeros(n_channels, device=source.device, dtype=torch.bool)
        selected[torch.as_tensor(self.channel_ids, device=source.device, dtype=torch.long)] = True

        moved = source[selected]
        remaining = source[~selected]

        data_dict[self.target_key] = moved if existing is None else torch.cat((existing, moved), dim=0)

        if remaining.shape[0] == 0:
            del data_dict[self.source_key]
        else:
            data_dict[self.source_key] = remaining

        return data_dict
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MaskImageTransform(BasicTransform):
    """
    Overwrite image voxels wherever a segmentation channel is negative.

    For every image channel index in ``apply_to_channels``, voxels where
    ``data_dict['segmentation'][channel_idx_in_seg] < 0`` are set to ``set_outside_to``.
    The image tensor is modified in place. (Negative segmentation values presumably mark
    voxels outside the nnU-Net nonzero mask — confirm against the producer.)
    """

    def __init__(self,
                 apply_to_channels: List[int],
                 channel_idx_in_seg: int = 0,
                 set_outside_to: float = 0,
                 ):
        super().__init__()
        self.apply_to_channels = apply_to_channels
        self.channel_idx_in_seg = channel_idx_in_seg
        self.set_outside_to = set_outside_to

    def apply(self, data_dict, **params):
        channels = self.apply_to_channels
        if len(channels) > 0:
            outside = data_dict['segmentation'][self.channel_idx_in_seg] < 0
            image = data_dict['image']
            for c in channels:
                image[c, outside] = self.set_outside_to
        return data_dict
|
|
24
|
+
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Convert3DTo2DTransform(BasicTransform):
    """
    Merge the channel axis with the first spatial axis: (c, x, ...) -> (c * x, ...).

    The original channel counts are written into the data dict under ``nchannels_img``,
    ``nchannels_seg`` and ``nchannels_regr_trg`` (for whichever of image / segmentation /
    regression_target are present) so that Convert2DTo3DTransform can undo the reshape.
    Bounding boxes and keypoints are not supported.
    """

    def apply(self, data_dict, **params):
        # Remember channel counts before reshaping so the inverse transform can restore them.
        for data_key, count_key in (('image', 'nchannels_img'),
                                    ('segmentation', 'nchannels_seg'),
                                    ('regression_target', 'nchannels_regr_trg')):
            if data_key in data_dict:
                data_dict[count_key] = data_dict[data_key].shape[0]
        return super().apply(data_dict, **params)

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        shape = tuple(img.shape)
        return img.reshape((shape[0] * shape[1],) + shape[2:])

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        # Same fold as for the image.
        return self._apply_to_image(regression_target, **params)

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        # Same fold as for the image.
        return self._apply_to_image(segmentation, **params)

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Convert2DTo3DTransform(BasicTransform):
    """
    Inverse of Convert3DTo2DTransform: (c * x, ...) -> (c, x, ...).

    The channel counts are read from ``nchannels_img`` / ``nchannels_seg`` /
    ``nchannels_regr_trg`` in the data dict, and those keys are removed afterwards.
    Bounding boxes and keypoints are not supported.
    """

    def get_parameters(self, **data_dict) -> dict:
        params = {}
        for key in ('nchannels_img', 'nchannels_seg', 'nchannels_regr_trg'):
            if key in data_dict:
                params[key] = data_dict[key]
        return params

    def apply(self, data_dict, **params):
        data_dict = super().apply(data_dict, **params)
        # The bookkeeping entries are no longer needed once the shapes are restored.
        for key in ('nchannels_img', 'nchannels_seg', 'nchannels_regr_trg'):
            if key in data_dict:
                del data_dict[key]
        return data_dict

    @staticmethod
    def _unfold(x: torch.Tensor, n_channels) -> torch.Tensor:
        """Split axis 0 of ``x`` into (n_channels, x.shape[0] // n_channels)."""
        return x.reshape((n_channels, x.shape[0] // n_channels, *x.shape[1:]))

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return self._unfold(img, params['nchannels_img'])

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        return self._unfold(segmentation, params['nchannels_seg'])

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        return self._unfold(regression_target, params['nchannels_regr_trg'])

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
if __name__ == '__main__':
    # Round-trip demo: fold a 3D sample into 2D and back, printing the shapes at each stage.
    image = torch.rand((2, 32, 64, 128))
    seg = torch.ones((1, 32, 64, 128))

    to_2d = Convert3DTo2DTransform()
    to_3d = Convert2DTo3DTransform()

    folded = to_2d(**{'image': image, 'segmentation': seg})
    print(folded['image'].shape, folded['segmentation'].shape)
    restored = to_3d(**folded)
    print(restored['image'].shape, restored['segmentation'].shape)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class RandomTransform(BasicTransform):
    """
    Apply the wrapped ``transform`` with probability ``apply_probability``; otherwise return
    the data dict unchanged. The decision is drawn from torch's global RNG.
    """

    def __init__(self, transform: BasicTransform, apply_probability: float = 1):
        super().__init__()
        self.transform = transform
        self.apply_probability = apply_probability

    def get_parameters(self, **data_dict) -> dict:
        draw = torch.rand(1).item()
        return {"apply_transform": draw < self.apply_probability}

    def apply(self, data_dict: dict, **params) -> dict:
        if not params['apply_transform']:
            return data_dict
        return self.transform(**data_dict)

    def __repr__(self):
        return f"{type(self).__name__}(p={self.apply_probability}, transform={self.transform})"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class OneOfTransform(BasicTransform):
    """
    Randomly selects and applies one transform from the provided list.

    Each transform must be a callable (usually a BasicTransform subclass).
    This does not override the internal probabilities of the transforms themselves.
    The choice is uniform and drawn from NumPy's global RNG. Note that `__call__` is
    overridden directly, so this class's own get_parameters/apply are not used.

    Args:
        list_of_transforms (List[BasicTransform]): A list of transform instances to choose from.
    """

    def __init__(self, list_of_transforms: List[BasicTransform]):
        super().__init__()
        self.list_of_transforms = list_of_transforms

    def __call__(self, **data_dict) -> dict:
        n = len(self.list_of_transforms)
        if n == 0:
            raise ValueError("OneOfTransform needs at least one transform to choose from")
        # Draw an index rather than calling np.random.choice(self.list_of_transforms).
        # choice() first converts the list into a NumPy array on every call. That is wasted
        # work and can misbehave for sequence-like objects. For uniform sampling, choice()
        # itself draws randint(0, len(a)), so seeded runs should still pick the same transforms.
        chosen_transform = self.list_of_transforms[np.random.randint(n)]
        return chosen_transform(**data_dict)

    def __repr__(self):
        return f"{type(self).__name__}(transforms={self.list_of_transforms})"
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import Union, Tuple, List
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RemoveLabelTransform(SegOnlyTransform):
    """
    Replace every occurrence of ``label_value`` with ``set_to`` in the segmentation.

    ``segmentation_channels`` selects the channels to process: a single index, or a list/tuple
    of indices. ``None`` processes all channels. The segmentation tensor is modified in place
    and returned.
    """

    def __init__(self, label_value: int, set_to: int, segmentation_channels: Union[int, Tuple[int, ...], List[int]] = None):
        if segmentation_channels is not None and not isinstance(segmentation_channels, (list, tuple)):
            segmentation_channels = [segmentation_channels]
        self.segmentation_channels = segmentation_channels
        self.label_value = label_value
        self.set_to = set_to
        super().__init__()

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        channels = (range(segmentation.shape[0]) if self.segmentation_channels is None
                    else self.segmentation_channels)
        for c in channels:
            channel = segmentation[c]  # view: writes go into `segmentation`
            channel[channel == self.label_value] = self.set_to
        return segmentation

# Alias under the misspelled name, presumably kept so existing imports of it keep working.
RemoveLabelTansform = RemoveLabelTransform
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from typing import Union, List, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
    """
    Convert a label map into one boolean channel per region.

    ``regions`` holds one entry per output channel; each entry is a single label or a
    collection of labels. Output channel ``k`` is True wherever
    ``segmentation[channel_in_seg]`` equals (one of) the label(s) of ``regions[k]``.
    """

    def __init__(self, regions: Union[List[Union[int, List[int]]], Tuple[Union[int, Tuple[int, ...]]]], channel_in_seg: int = 0):
        super().__init__()
        # as_tensor + atleast_1d instead of torch.Tensor(i): the legacy constructor treats a
        # lone integer-like scalar that is not a Python int (e.g. numpy.int64(3)) as a *size*
        # and returns an uninitialized tensor. The default float dtype of torch.Tensor is kept.
        self.regions = [
            torch.atleast_1d(torch.as_tensor(i, dtype=torch.get_default_dtype())) for i in regions
        ]
        self.channel_in_seg = channel_in_seg

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        labels_map = segmentation[self.channel_in_seg]
        num_regions = len(self.regions)
        region_output = torch.zeros((num_regions, *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
        for region_id, region_labels in enumerate(self.regions):
            # The label tensors are created on the CPU in __init__. Move them to the segmentation's
            # device; a 1-D CPU tensor cannot be compared with (or isin-tested against) a CUDA tensor.
            # .to() is a no-op when the devices already match.
            region_labels = region_labels.to(labels_map.device)
            if len(region_labels) == 1:
                region_output[region_id] = labels_map == region_labels
            else:
                region_output[region_id] = torch.isin(labels_map, region_labels)
        # we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to
        # device followed by cast on device should be faster than having fp32 here and transferring that
        return region_output
|
|
24
|
+
|