batchgeneratorsv2 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. batchgeneratorsv2/benchmarks/__init__.py +0 -0
  2. batchgeneratorsv2/benchmarks/bg_comparison/__init__.py +0 -0
  3. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py +90 -0
  4. batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_here.py +138 -0
  5. batchgeneratorsv2/benchmarks/unique_values.py +55 -0
  6. batchgeneratorsv2/dataloading/__init__.py +0 -0
  7. batchgeneratorsv2/helpers/__init__.py +0 -0
  8. batchgeneratorsv2/helpers/fft_conv.py +149 -0
  9. batchgeneratorsv2/helpers/scalar_type.py +28 -0
  10. batchgeneratorsv2/transforms/__init__.py +0 -0
  11. batchgeneratorsv2/transforms/base/__init__.py +0 -0
  12. batchgeneratorsv2/transforms/base/basic_transform.py +77 -0
  13. batchgeneratorsv2/transforms/intensity/__init__.py +0 -0
  14. batchgeneratorsv2/transforms/intensity/brightness.py +123 -0
  15. batchgeneratorsv2/transforms/intensity/contrast.py +123 -0
  16. batchgeneratorsv2/transforms/intensity/gamma.py +135 -0
  17. batchgeneratorsv2/transforms/intensity/gaussian_noise.py +104 -0
  18. batchgeneratorsv2/transforms/intensity/inversion.py +51 -0
  19. batchgeneratorsv2/transforms/intensity/random_clip.py +101 -0
  20. batchgeneratorsv2/transforms/local/__init__.py +0 -0
  21. batchgeneratorsv2/transforms/local/brightness_gradient.py +177 -0
  22. batchgeneratorsv2/transforms/local/local_contrast.py +90 -0
  23. batchgeneratorsv2/transforms/local/local_gamma.py +104 -0
  24. batchgeneratorsv2/transforms/local/local_smoothing.py +98 -0
  25. batchgeneratorsv2/transforms/local/local_transform.py +86 -0
  26. batchgeneratorsv2/transforms/nnunet/__init__.py +0 -0
  27. batchgeneratorsv2/transforms/nnunet/random_binary_operator.py +190 -0
  28. batchgeneratorsv2/transforms/nnunet/remove_connected_components.py +86 -0
  29. batchgeneratorsv2/transforms/nnunet/seg_to_onehot.py +32 -0
  30. batchgeneratorsv2/transforms/noise/__init__.py +0 -0
  31. batchgeneratorsv2/transforms/noise/blank_rectangle.py +150 -0
  32. batchgeneratorsv2/transforms/noise/gaussian_blur.py +260 -0
  33. batchgeneratorsv2/transforms/noise/median_filter.py +52 -0
  34. batchgeneratorsv2/transforms/noise/rician.py +61 -0
  35. batchgeneratorsv2/transforms/noise/sharpen.py +128 -0
  36. batchgeneratorsv2/transforms/spatial/__init__.py +0 -0
  37. batchgeneratorsv2/transforms/spatial/channel_misalignment.py +224 -0
  38. batchgeneratorsv2/transforms/spatial/low_resolution.py +92 -0
  39. batchgeneratorsv2/transforms/spatial/mirroring.py +71 -0
  40. batchgeneratorsv2/transforms/spatial/rot90.py +78 -0
  41. batchgeneratorsv2/transforms/spatial/spatial.py +601 -0
  42. batchgeneratorsv2/transforms/spatial/transpose.py +67 -0
  43. batchgeneratorsv2/transforms/utils/__init__.py +0 -0
  44. batchgeneratorsv2/transforms/utils/compose.py +89 -0
  45. batchgeneratorsv2/transforms/utils/cropping.py +73 -0
  46. batchgeneratorsv2/transforms/utils/deep_supervision_downsampling.py +59 -0
  47. batchgeneratorsv2/transforms/utils/move_channels.py +52 -0
  48. batchgeneratorsv2/transforms/utils/nnunet_masking.py +24 -0
  49. batchgeneratorsv2/transforms/utils/pseudo2d.py +79 -0
  50. batchgeneratorsv2/transforms/utils/random.py +46 -0
  51. batchgeneratorsv2/transforms/utils/remove_label.py +27 -0
  52. batchgeneratorsv2/transforms/utils/seg_to_regions.py +24 -0
  53. batchgeneratorsv2-0.3.2.dist-info/METADATA +252 -0
  54. batchgeneratorsv2-0.3.2.dist-info/RECORD +57 -0
  55. batchgeneratorsv2-0.3.2.dist-info/WHEEL +5 -0
  56. batchgeneratorsv2-0.3.2.dist-info/licenses/LICENSE +201 -0
  57. batchgeneratorsv2-0.3.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from time import perf_counter
5
+ from typing import List, Optional, Dict, Any
6
+
7
+ import numpy as np
8
+
9
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
10
+
11
+
12
class ComposeTransforms(BasicTransform):
    """Run a fixed sequence of transforms, threading the data dict through them.

    Each transform is called as ``transform(**data_dict)`` and its returned dict becomes the
    input of the next transform. The ``params`` passed to :meth:`apply` are not forwarded.
    """

    def __init__(self, transforms: List[BasicTransform]):
        super().__init__()
        # applied in list order
        self.transforms = transforms

    def apply(self, data_dict, **params):
        result = data_dict
        for transform in self.transforms:
            result = transform(**result)
        return result
21
+
22
+
23
+ @dataclass
24
+ class _TimingStats:
25
+ total_s: float = 0.0
26
+ n: int = 0
27
+
28
+ def add(self, dt: float) -> None:
29
+ self.total_s += dt
30
+ self.n += 1
31
+
32
+ @property
33
+ def mean_s(self) -> float:
34
+ return self.total_s / self.n if self.n > 0 else 0.0
35
+
36
+
37
class TimedComposeTransforms(BasicTransform):
    """
    ComposeTransforms variant that measures per-transform wall-clock time and prints
    average times after every `print_every` calls to `apply`.

    Args:
        transforms: transforms applied in order; each is called as ``t(**data_dict)``.
        print_every: a report is due every ``print_every`` calls to ``apply``. The value is
            converted with ``int()`` first and must then be >= 1.
        name: label used in the report header; defaults to the class name.
        p_write: probability that a due report is actually printed. Values >= 1 always print,
            values <= 0 never print. In those two cases no random number is drawn, so adding
            the timer does not shift the global NumPy RNG stream.

    Notes:
        - Measures wall clock time via perf_counter.
        - For CPU-only pipelines, this is representative. For GPU, you'd need synchronization.
        - Averages are cumulative since construction or the last ``reset_timings()`` call.
    """

    def __init__(self, transforms: List[BasicTransform], print_every: int = 100, name: Optional[str] = None,
                 p_write: float = 1.0):
        super().__init__()
        # Convert before validating: checking first let e.g. 0.5 through, which int() turns
        # into 0 and later caused a ZeroDivisionError in apply().
        print_every = int(print_every)
        if print_every <= 0:
            raise ValueError("print_every must be >= 1")
        self.transforms = transforms
        self.print_every = print_every
        self.name = name or self.__class__.__name__

        self._iter = 0
        self.p_write = p_write
        # one accumulator per pipeline position (keyed by index, so repeated transform types stay separate)
        self._stats: Dict[int, _TimingStats] = {i: _TimingStats() for i in range(len(transforms))}

    def reset_timings(self) -> None:
        """Reset accumulated timing statistics and iteration counter."""
        self._iter = 0
        for s in self._stats.values():
            s.total_s = 0.0
            s.n = 0

    def _transform_display_name(self, t: BasicTransform) -> str:
        # Prefer explicit "name" attribute if present, otherwise class name
        return getattr(t, "name", None) or t.__class__.__name__

    def _print_report(self) -> None:
        """Print the mean time per transform (in ms), in pipeline order."""
        lines = [f"[{self.name}] Average transform times over last {self._iter} iterations:"]
        for i, t in enumerate(self.transforms):
            st = self._stats[i]
            lines.append(f" {i:02d} {self._transform_display_name(t)}: {st.mean_s * 1e3:.3f} ms")
        print("\n".join(lines), flush=True)

    def _should_write(self) -> bool:
        """Decide whether a due report gets printed; consults the global NumPy RNG only for 0 < p_write < 1."""
        if self.p_write >= 1:
            return True
        if self.p_write <= 0:
            return False
        return np.random.uniform() < self.p_write

    def apply(self, data_dict: Dict[str, Any], **params) -> Dict[str, Any]:
        for i, t in enumerate(self.transforms):
            t0 = perf_counter()
            data_dict = t(**data_dict)
            self._stats[i].add(perf_counter() - t0)

        self._iter += 1
        if self._iter % self.print_every == 0 and self._should_write():
            self._print_report()

        return data_dict
@@ -0,0 +1,73 @@
1
+ import torch
2
+
3
+
4
def crop_tensor(input_tensor, center, crop_size, pad_mode='constant', pad_kwargs=None):
    """
    Crops and pads an input tensor based on the specified center and crop size. Padding can be customized.

    Along every spatial axis the requested window is ``[c - s // 2, c - s // 2 + s)``, where ``c`` is
    the center coordinate and ``s`` the crop size. The part of the window that overlaps the tensor is
    copied. The remainder is filled by ``torch.nn.functional.pad``, so the result always has spatial
    shape ``crop_size``. This also holds when the window lies partly or entirely outside the tensor.

    Parameters:
    - input_tensor (torch.Tensor): The input tensor with shape (c, x, y) or (c, x, y, z).
    - center (tuple): The center coordinates of the crop (x, y) or (x, y, z).
    - crop_size (tuple): The size of the crop (width, height) or (width, height, depth).
    - pad_mode (str): The mode to use for padding (see torch.nn.functional.pad documentation).
    - pad_kwargs (dict, optional): Additional keyword arguments for padding. Defaults to ``{'value': 0}``.

    Returns:
    - torch.Tensor: The cropped and possibly padded tensor.
    """
    if pad_kwargs is None:
        pad_kwargs = {'value': 0}

    dim = len(center)  # number of spatial dimensions
    assert len(crop_size) == dim, "Crop size and center must have the same number of dimensions"
    assert input_tensor.ndim - 1 == dim, "Crop size and input_tensor must have the same number of spatial dimensions"

    spatial_shape = input_tensor.shape[-dim:]

    slices = [slice(None)]  # keep all channels
    pad_before = []
    pad_after = []
    for cen, cs, sh in zip(center, crop_size, spatial_shape):
        lo = cen - cs // 2  # requested window is [lo, hi)
        hi = lo + cs
        # overlap of the window with [0, sh); clamped so that start <= end always holds
        start = min(max(lo, 0), sh)
        end = max(min(hi, sh), start)
        slices.append(slice(start, end))
        if end > start:
            # both values are >= 0 and (start - lo) + (end - start) + (hi - end) == cs
            pad_before.append(start - lo)
            pad_after.append(hi - end)
        else:
            # window does not touch the tensor at all: the whole crop is padding
            pad_before.append(cs)
            pad_after.append(0)

    cropped = input_tensor[tuple(slices)]

    # F.pad expects (before, after) pairs starting from the LAST dimension
    pad_width = []
    for b, a in zip(reversed(pad_before), reversed(pad_after)):
        pad_width.extend((b, a))
    if any(pad_width):
        cropped = torch.nn.functional.pad(cropped, pad_width, mode=pad_mode, **pad_kwargs)

    return cropped
49
+
50
+
51
def center_crop(input_tensor, crop_size, pad_mode='constant', pad_kwargs=None):
    """
    Crop a window of size ``crop_size`` centred on the middle of the tensor's spatial axes.

    The centre of each spatial axis of extent ``n`` is taken as ``n // 2``. Cropping and any
    padding beyond the borders are delegated to ``crop_tensor`` with the given pad settings.

    Parameters:
    - input_tensor (torch.Tensor): The input tensor with shape (c, x, y) or (c, x, y, z).
    - crop_size (tuple): The size of the crop (width, height) or (width, height, depth).
    - pad_mode (str): The mode to use for padding (see torch.nn.functional.pad documentation).
    - pad_kwargs (dict, optional): Additional keyword arguments for padding.

    Returns:
    - torch.Tensor: The center-cropped and possibly padded tensor.
    """
    n_spatial = input_tensor.ndim - 1  # leading axis is channels
    midpoint = tuple(extent // 2 for extent in input_tensor.shape[-n_spatial:])
    return crop_tensor(input_tensor, midpoint, crop_size, pad_mode, pad_kwargs)
73
+
@@ -0,0 +1,59 @@
1
+ from typing import Tuple, List, Union
2
+ import torch
3
+
4
+ from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
5
+ from torch.nn.functional import interpolate
6
+
7
+
8
class DownsampleSegForDSTransform(SegOnlyTransform):
    """Produce one nearest-neighbour-resampled copy of the segmentation per entry in ``ds_scales``.

    Each scale is either a scalar (used for every spatial axis) or a sequence with one factor per
    spatial axis. A scale whose factors are all 1 returns the input tensor itself (no copy).
    """

    def __init__(self, ds_scales: Union[List, Tuple]):
        super().__init__()
        self.ds_scales = ds_scales

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> List[torch.Tensor]:
        n_spatial = segmentation.ndim - 1  # first axis is channels
        outputs = []
        for scale in self.ds_scales:
            if isinstance(scale, (tuple, list)):
                assert len(scale) == n_spatial
                factors = scale
            else:
                factors = [scale] * n_spatial

            if all(f == 1 for f in factors):
                outputs.append(segmentation)
                continue

            target_shape = [round(extent * f) for extent, f in zip(segmentation.shape[1:], factors)]
            original_dtype = segmentation.dtype
            # interpolate is not defined for short etc, so resample in float32 and cast back
            resampled = interpolate(segmentation[None].float(), target_shape, mode='nearest-exact')[0]
            outputs.append(resampled.to(original_dtype))
        return outputs
29
+
30
+
31
if __name__ == '__main__':
    # Ad-hoc timing comparison against nnU-Net's reference implementation (requires nnunetv2).
    import os
    from time import time

    import numpy as np

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.set_num_threads(1)

    ds_transform = DownsampleSegForDSTransform((1, 0.5, 0.25))

    times_torch = []
    for _ in range(1):
        seg = torch.round(5 * torch.rand((2, 128, 192, 64)), decimals=0).to(torch.uint8)
        start = time()
        out = ds_transform(segmentation=seg)
        times_torch.append(time() - start)
    print('torch', np.mean(times_torch))

    from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
        DownsampleSegForDSTransform2

    reference_transform = DownsampleSegForDSTransform2((1, 0.5, 0.25), order=0)
    times_bg = []
    for _ in range(1):
        seg_np = np.round(5 * np.random.uniform(size=(1, 2, 128, 192, 64)), decimals=0).astype(np.uint8)
        start = time()
        out = reference_transform(seg=seg_np)
        times_bg.append(time() - start)
    print('bg', np.mean(times_bg))
@@ -0,0 +1,52 @@
1
+ from typing import Tuple, Union
2
+ import torch
3
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
4
+
5
+
6
class MoveChannelsTransform(BasicTransform):
    """Move selected channels from ``data_dict[source_key]`` to ``data_dict[target_key]``.

    The selected channels are removed from the source and concatenated (dim 0) after the
    existing target channels. The target entry is created if it is absent. Channels are taken in
    ascending index order, regardless of the order given in ``channel_ids``, because selection
    uses a boolean mask. If no channels remain in the source, the source key is deleted.
    """

    def __init__(self, channel_ids: Union[int, Tuple[int, ...]], source_key: str, target_key: str):
        super().__init__()
        self.channel_ids: Tuple[int, ...] = (channel_ids,) if isinstance(channel_ids, int) else tuple(channel_ids)
        self.source_key = source_key
        self.target_key = target_key

    def get_parameters(self, **data_dict) -> dict:
        # deterministic transform: nothing to sample
        return {}

    def apply(self, data_dict, **params):
        source = data_dict[self.source_key]
        assert source.ndim in (3, 4), f"Expected (C,X,Y) or (C,X,Y,Z), got {source.shape}"

        has_target = self.target_key in data_dict
        if has_target:
            existing = data_dict[self.target_key]
            assert source.ndim == existing.ndim, "source and target key must have the same number of dimensions"
            assert source.shape[1:] == existing.shape[1:], (
                f"spatial dimensions must match. Got source: {source.shape} and target: {existing.shape}"
            )

        selected = torch.zeros(source.shape[0], device=source.device, dtype=torch.bool)
        selected[torch.as_tensor(self.channel_ids, device=source.device, dtype=torch.long)] = True

        moved = source[selected]
        remaining = source[~selected]

        data_dict[self.target_key] = torch.cat((existing, moved), dim=0) if has_target else moved

        if remaining.shape[0] == 0:
            del data_dict[self.source_key]
        else:
            data_dict[self.source_key] = remaining

        return data_dict
@@ -0,0 +1,24 @@
1
+ from typing import List
2
+
3
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
4
+
5
+
6
class MaskImageTransform(BasicTransform):
    """Overwrite image values wherever a chosen segmentation channel is negative.

    For every channel index in ``apply_to_channels``, voxels of ``data_dict['image']`` where
    ``data_dict['segmentation'][channel_idx_in_seg] < 0`` are set to ``set_outside_to``, in place.
    NOTE(review): negative segmentation values presumably mark "outside" voxels — confirm upstream.
    """

    def __init__(self,
                 apply_to_channels: List[int],
                 channel_idx_in_seg: int = 0,
                 set_outside_to: float = 0,
                 ):
        super().__init__()
        self.apply_to_channels = apply_to_channels
        self.channel_idx_in_seg = channel_idx_in_seg
        self.set_outside_to = set_outside_to

    def apply(self, data_dict, **params):
        channels = self.apply_to_channels
        if len(channels) > 0:
            outside = data_dict['segmentation'][self.channel_idx_in_seg] < 0
            image = data_dict['image']
            for c in channels:
                image[c, outside] = self.set_outside_to
        return data_dict
24
+
@@ -0,0 +1,79 @@
1
+ import torch
2
+
3
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
4
+
5
+
6
class Convert3DTo2DTransform(BasicTransform):
    """Fold the first spatial axis into the channel axis: (C, X, ...) -> (C * X, ...).

    Before folding, the original channel counts are stored in ``data_dict`` under 'nchannels_img',
    'nchannels_seg' and 'nchannels_regr_trg', so that the fold can be undone later.
    """

    def apply(self, data_dict, **params):
        for source_key, count_key in (('image', 'nchannels_img'),
                                      ('segmentation', 'nchannels_seg'),
                                      ('regression_target', 'nchannels_regr_trg')):
            if source_key in data_dict:
                data_dict[count_key] = data_dict[source_key].shape[0]
        return super().apply(data_dict, **params)

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        folded_channels = img.shape[0] * img.shape[1]
        return img.reshape((folded_channels, *img.shape[2:]))

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        return self._apply_to_image(regression_target, **params)

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        return self._apply_to_image(segmentation, **params)

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
28
+
29
+ def _apply_to_keypoints(self, keypoints, **params):
30
+ raise NotImplementedError
31
+
32
+
33
class Convert2DTo3DTransform(BasicTransform):
    """Split a folded channel axis back out: (C * X, ...) -> (C, X, ...).

    The channel counts ``C`` are read from 'nchannels_img', 'nchannels_seg' and
    'nchannels_regr_trg' in ``data_dict``. They are passed to the per-target hooks as params and
    removed from ``data_dict`` once the transform has run.
    """

    @staticmethod
    def _unfold_channels(tensor: torch.Tensor, n_channels: int) -> torch.Tensor:
        """Reshape (n_channels * X, ...) into (n_channels, X, ...)."""
        return tensor.reshape((n_channels, tensor.shape[0] // n_channels, *tensor.shape[1:]))

    def get_parameters(self, **data_dict) -> dict:
        params = {}
        for key in ('nchannels_img', 'nchannels_seg', 'nchannels_regr_trg'):
            if key in data_dict:
                params[key] = data_dict[key]
        return params

    def apply(self, data_dict, **params):
        data_dict = super().apply(data_dict, **params)
        # drop the bookkeeping entries; they are meaningless once the data is 3D again
        for key in ('nchannels_img', 'nchannels_seg', 'nchannels_regr_trg'):
            data_dict.pop(key, None)
        return data_dict

    def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:
        return self._unfold_channels(img, params['nchannels_img'])

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        return self._unfold_channels(segmentation, params['nchannels_seg'])

    def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor:
        return self._unfold_channels(regression_target, params['nchannels_regr_trg'])

    def _apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def _apply_to_keypoints(self, keypoints, **params):
        raise NotImplementedError
65
+
66
+
67
if __name__ == '__main__':
    # Round-trip demo: fold 3D data to 2D and back, printing shapes at each stage.
    image = torch.rand((2, 32, 64, 128))
    seg = torch.ones((1, 32, 64, 128))

    to_2d = Convert3DTo2DTransform()
    to_3d = Convert2DTo3DTransform()

    folded = to_2d(image=image, segmentation=seg)
    print(folded['image'].shape, folded['segmentation'].shape)
    restored = to_3d(**folded)
    print(restored['image'].shape, restored['segmentation'].shape)
@@ -0,0 +1,46 @@
1
+ from typing import List
2
+
3
+ import torch
4
+
5
+ from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
6
+ import numpy as np
7
+
8
+
9
class RandomTransform(BasicTransform):
    """Apply the wrapped transform with probability ``apply_probability``.

    The decision is drawn from ``torch.rand`` in :meth:`get_parameters`. When the transform is
    skipped, ``data_dict`` is returned unchanged.
    """

    def __init__(self, transform: BasicTransform, apply_probability: float = 1):
        super().__init__()
        self.transform = transform
        self.apply_probability = apply_probability

    def get_parameters(self, **data_dict) -> dict:
        draw = torch.rand(1).item()
        return {"apply_transform": draw < self.apply_probability}

    def apply(self, data_dict: dict, **params) -> dict:
        if not params['apply_transform']:
            return data_dict
        return self.transform(**data_dict)

    def __repr__(self):
        return f"{type(self).__name__}(p={self.apply_probability}, transform={self.transform})"
27
+
28
+
29
class OneOfTransform(BasicTransform):
    """
    Randomly selects and applies one transform from the provided list.

    Each transform must be a callable (usually a BasicTransform subclass).
    This does not override the internal probabilities of the transforms themselves.
    Selection is uniform and uses the global NumPy RNG.

    Args:
        list_of_transforms (List[BasicTransform]): A list of transform instances to choose from.

    Raises:
        ValueError: when called with an empty ``list_of_transforms``.
    """

    def __init__(self, list_of_transforms: List[BasicTransform]):
        super().__init__()
        self.list_of_transforms = list_of_transforms

    def __call__(self, **data_dict) -> dict:
        # Draw an index instead of calling np.random.choice on the list itself. choice() first
        # converts the transform objects into an ndarray on every call. That is wasted work, and
        # numpy may try to unpack a transform that happens to be sequence-like. For uniform
        # sampling with replacement, legacy RandomState.choice draws its index via randint(0, n).
        # randint(n) therefore consumes the global RNG identically and picks the same transform.
        idx = np.random.randint(len(self.list_of_transforms))
        return self.list_of_transforms[idx](**data_dict)

    def __repr__(self):
        return f"{type(self).__name__}(transforms={self.list_of_transforms})"
@@ -0,0 +1,27 @@
1
+ from typing import Union, Tuple, List
2
+
3
+ import torch
4
+
5
+ from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
6
+
7
+
8
class RemoveLabelTransform(SegOnlyTransform):
    """Replace every occurrence of ``label_value`` with ``set_to`` in selected segmentation channels.

    ``segmentation_channels`` may be a single index, a list/tuple of indices, or None (meaning all
    channels). The segmentation tensor is modified in place and returned.
    """

    def __init__(self, label_value: int, set_to: int,
                 segmentation_channels: Union[int, Tuple[int, ...], List[int]] = None):
        if segmentation_channels is not None and not isinstance(segmentation_channels, (list, tuple)):
            segmentation_channels = [segmentation_channels]
        self.segmentation_channels = segmentation_channels
        self.label_value = label_value
        self.set_to = set_to
        super().__init__()

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        if self.segmentation_channels is None:
            channels = range(segmentation.shape[0])
        else:
            channels = self.segmentation_channels
        for c in channels:
            channel = segmentation[c]  # a view, so the masked assignment writes into segmentation
            channel[channel == self.label_value] = self.set_to
        return segmentation


# Backwards-compatible alias for the historical misspelling of the class name.
RemoveLabelTansform = RemoveLabelTransform
@@ -0,0 +1,24 @@
1
+ from typing import Union, List, Tuple
2
+ import torch
3
+
4
+ from batchgeneratorsv2.transforms.base.basic_transform import SegOnlyTransform
5
+
6
+
7
class ConvertSegmentationToRegionsTransform(SegOnlyTransform):
    """Convert one label channel into a stack of boolean region masks.

    Each entry of ``regions`` is a single label or a collection of labels. Output channel ``k`` is
    True wherever ``segmentation[channel_in_seg]`` equals any label of ``regions[k]``. The result
    has shape ``(len(regions), *segmentation.shape[1:])`` and lives on the segmentation's device.
    """

    def __init__(self, regions: Union[List[Union[int, List[int]]], Tuple[Union[int, Tuple[int, ...]]]], channel_in_seg: int = 0):
        super().__init__()
        # as_tensor(...).reshape(-1) turns both scalars and sequences into 1-D label tensors.
        # The previous torch.Tensor(i) call is the legacy *size* constructor for integer-like
        # scalars that are not Python ints (e.g. numpy.int64). For those it produced an
        # uninitialised tensor of that length instead of a tensor holding the label.
        self.regions = [torch.as_tensor(i).reshape(-1) for i in regions]
        self.channel_in_seg = channel_in_seg

    def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor:
        label_map = segmentation[self.channel_in_seg]
        region_output = torch.zeros((len(self.regions), *segmentation.shape[1:]), dtype=torch.bool, device=segmentation.device)
        for region_id, region_labels in enumerate(self.regions):
            # Label tensors are built on the CPU in __init__. Move them next to the data so that
            # non-CPU segmentations can be compared against them (no-op when already there).
            region_labels = region_labels.to(segmentation.device)
            if region_labels.numel() == 1:
                region_output[region_id] = label_map == region_labels
            else:
                region_output[region_id] = torch.isin(label_map, region_labels)
        # we return bool here and leave it to the loss function to cast it to whatever it needs. Transferring bool to
        # device followed by cast on device should be faster than having fp32 here and transferring that
        return region_output
24
+