careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package registry page for more details.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py CHANGED
@@ -3,12 +3,14 @@ UNet model.
3
3
 
4
4
  A UNet encoder, decoder and complete model.
5
5
  """
6
- from typing import Callable, List, Optional
6
+ from typing import Any, List, Union
7
7
 
8
8
  import torch
9
9
  import torch.nn as nn
10
10
 
11
- from .layers import Conv_Block
11
+ from ..config.support import SupportedActivation
12
+ from .activation import get_activation
13
+ from .layers import Conv_Block, MaxBlurPool
12
14
 
13
15
 
14
16
  class UnetEncoder(nn.Module):
@@ -42,6 +44,7 @@ class UnetEncoder(nn.Module):
42
44
  use_batch_norm: bool = True,
43
45
  dropout: float = 0.0,
44
46
  pool_kernel: int = 2,
47
+ n2v2: bool = False,
45
48
  ) -> None:
46
49
  """
47
50
  Constructor.
@@ -65,7 +68,11 @@ class UnetEncoder(nn.Module):
65
68
  """
66
69
  super().__init__()
67
70
 
68
- self.pooling = getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
71
+ self.pooling = (
72
+ getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
73
+ if not n2v2
74
+ else MaxBlurPool(dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel)
75
+ )
69
76
 
70
77
  encoder_blocks = []
71
78
 
@@ -82,7 +89,6 @@ class UnetEncoder(nn.Module):
82
89
  )
83
90
  )
84
91
  encoder_blocks.append(self.pooling)
85
-
86
92
  self.encoder_blocks = nn.ModuleList(encoder_blocks)
87
93
 
88
94
  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
@@ -134,6 +140,7 @@ class UnetDecoder(nn.Module):
134
140
  num_channels_init: int = 64,
135
141
  use_batch_norm: bool = True,
136
142
  dropout: float = 0.0,
143
+ n2v2: bool = False,
137
144
  ) -> None:
138
145
  """
139
146
  Constructor.
@@ -157,6 +164,9 @@ class UnetDecoder(nn.Module):
157
164
  scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
158
165
  )
159
166
  in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
167
+
168
+ self.n2v2 = n2v2
169
+
160
170
  self.bottleneck = Conv_Block(
161
171
  conv_dim,
162
172
  in_channels=in_channels,
@@ -169,12 +179,18 @@ class UnetDecoder(nn.Module):
169
179
  decoder_blocks = []
170
180
  for n in range(depth):
171
181
  decoder_blocks.append(upsampling)
172
- in_channels = num_channels_init * 2 ** (depth - n)
173
- out_channels = num_channels_init
182
+ in_channels = (
183
+ num_channels_init ** (depth - n)
184
+ if (self.n2v2 and n == depth - 1)
185
+ else num_channels_init * 2 ** (depth - n)
186
+ )
187
+ out_channels = in_channels // 2
174
188
  decoder_blocks.append(
175
189
  Conv_Block(
176
190
  conv_dim,
177
- in_channels=in_channels,
191
+ in_channels=in_channels + in_channels // 2
192
+ if n > 0
193
+ else in_channels,
178
194
  out_channels=out_channels,
179
195
  intermediate_channel_multiplier=2,
180
196
  dropout_perc=dropout,
@@ -200,13 +216,19 @@ class UnetDecoder(nn.Module):
200
216
  torch.Tensor
201
217
  Output of the decoder.
202
218
  """
203
- x = features[0]
204
- skip_connections = features[1:][::-1]
219
+ x: torch.Tensor = features[0]
220
+ skip_connections: torch.Tensor = features[1:][::-1]
221
+
205
222
  x = self.bottleneck(x)
223
+
206
224
  for i, module in enumerate(self.decoder_blocks):
207
225
  x = module(x)
208
226
  if isinstance(module, nn.Upsample):
209
- x = torch.cat([x, skip_connections[i // 2]], axis=1)
227
+ if self.n2v2:
228
+ if x.shape != skip_connections[-1].shape:
229
+ x = torch.cat([x, skip_connections[i // 2]], axis=1)
230
+ else:
231
+ x = torch.cat([x, skip_connections[i // 2]], axis=1)
210
232
  return x
211
233
 
212
234
 
@@ -214,12 +236,12 @@ class UNet(nn.Module):
214
236
  """
215
237
  UNet model.
216
238
 
217
- Adapted for PyTorch from
239
+ Adapted for PyTorch from:
218
240
  https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.
219
241
 
220
242
  Parameters
221
243
  ----------
222
- conv_dim : int
244
+ conv_dims : int
223
245
  Number of dimensions of the convolution layers (2 or 3).
224
246
  num_classes : int, optional
225
247
  Number of classes to predict, by default 1.
@@ -241,7 +263,7 @@ class UNet(nn.Module):
241
263
 
242
264
  def __init__(
243
265
  self,
244
- conv_dim: int,
266
+ conv_dims: int,
245
267
  num_classes: int = 1,
246
268
  in_channels: int = 1,
247
269
  depth: int = 3,
@@ -249,14 +271,16 @@ class UNet(nn.Module):
249
271
  use_batch_norm: bool = True,
250
272
  dropout: float = 0.0,
251
273
  pool_kernel: int = 2,
252
- last_activation: Optional[Callable] = None,
274
+ final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
275
+ n2v2: bool = False,
276
+ **kwargs: Any,
253
277
  ) -> None:
254
278
  """
255
279
  Constructor.
256
280
 
257
281
  Parameters
258
282
  ----------
259
- conv_dim : int
283
+ conv_dims : int
260
284
  Number of dimensions of the convolution layers (2 or 3).
261
285
  num_classes : int, optional
262
286
  Number of classes to predict, by default 1.
@@ -278,28 +302,30 @@ class UNet(nn.Module):
278
302
  super().__init__()
279
303
 
280
304
  self.encoder = UnetEncoder(
281
- conv_dim,
305
+ conv_dims,
282
306
  in_channels=in_channels,
283
307
  depth=depth,
284
308
  num_channels_init=num_channels_init,
285
309
  use_batch_norm=use_batch_norm,
286
310
  dropout=dropout,
287
311
  pool_kernel=pool_kernel,
312
+ n2v2=n2v2,
288
313
  )
289
314
 
290
315
  self.decoder = UnetDecoder(
291
- conv_dim,
316
+ conv_dims,
292
317
  depth=depth,
293
318
  num_channels_init=num_channels_init,
294
319
  use_batch_norm=use_batch_norm,
295
320
  dropout=dropout,
321
+ n2v2=n2v2,
296
322
  )
297
- self.final_conv = getattr(nn, f"Conv{conv_dim}d")(
323
+ self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
298
324
  in_channels=num_channels_init,
299
325
  out_channels=num_classes,
300
326
  kernel_size=1,
301
327
  )
302
- self.last_activation = last_activation if last_activation else nn.Identity()
328
+ self.final_activation = get_activation(final_activation)
303
329
 
304
330
  def forward(self, x: torch.Tensor) -> torch.Tensor:
305
331
  """
@@ -318,5 +344,5 @@ class UNet(nn.Module):
318
344
  encoder_features = self.encoder(x)
319
345
  x = self.decoder(*encoder_features)
320
346
  x = self.final_conv(x)
321
- x = self.last_activation(x)
347
+ x = self.final_activation(x)
322
348
  return x
careamics/prediction/__init__.py CHANGED
@@ -2,8 +2,6 @@
2
2
 
3
3
  __all__ = [
4
4
  "stitch_prediction",
5
- "tta_backward",
6
- "tta_forward",
7
5
  ]
8
6
 
9
- from .prediction_utils import stitch_prediction, tta_backward, tta_forward
7
+ from .stitch_prediction import stitch_prediction
careamics/prediction/stitch_prediction.py ADDED
@@ -0,0 +1,73 @@
"""
Prediction convenience functions.

These functions are used during prediction.
"""
from typing import List

import numpy as np
import torch


def stitch_prediction(
    tiles: List[torch.Tensor],
    stitching_data: List[List[torch.Tensor]],
) -> torch.Tensor:
    """
    Stitch tiles back together to form a full image.

    Parameters
    ----------
    tiles : List[torch.Tensor]
        Predicted tiles, one tensor per batch; the first dimension of each tensor
        indexes the tiles within that batch.
    stitching_data : List[List[torch.Tensor]]
        Stitching information aligned with `tiles`, obtained from
        `dataset.tiled_patching.extract_tiles` (presumably after collation by the
        data loader - confirm). Each element unpacks into
        `(shape, overlap_crop_coords, stitch_coords)`:

        - the shape of the full image (nested one level deeper when the batch
          size is larger than one),
        - per-axis `(start, end)` coordinates used to crop each tile,
        - per-axis `(start, end)` coordinates at which each cropped tile is
          written into the full image.

        Each coordinate is indexed by the tile position within its batch.

    Returns
    -------
    torch.Tensor
        Full stitched image, as a float32 tensor.
    """
    # Retrieve the whole array size. There are two cases to consider:
    # 1. the tiles are stored in a list (batch size of one)
    # 2. the tiles are stored in a list with batches along the first dim, in
    #    which case the shape information is nested one level deeper
    if tiles[0].shape[0] > 1:
        shape_info = stitching_data[0][0][0]
    else:
        shape_info = stitching_data[0][0]
    input_shape = np.array([el.numpy() for el in shape_info], dtype=int).squeeze()

    # `np.atleast_1d` keeps a single-axis shape iterable after `squeeze` turned it
    # into a 0-d array.
    predicted_image = torch.zeros(
        tuple(int(size) for size in np.atleast_1d(input_shape)),
        dtype=torch.float32,
    )

    for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip(
        tiles, stitching_data
    ):
        for batch_idx in range(tile_batch.shape[0]):
            # Compute coordinates for cropping the predicted tile
            crop_slices = tuple(
                slice(c[0][batch_idx], c[1][batch_idx])
                for c in overlap_crop_coords_batch
            )

            # Crop the predicted tile according to the overlap coordinates
            cropped_tile = tile_batch[batch_idx].squeeze()[crop_slices]

            # Insert the cropped tile into the predicted image using the stitch
            # coordinates; the leading ellipsis makes them address the trailing
            # dimensions of the image.
            stitch_slices = (
                ...,
                *(
                    slice(c[0][batch_idx], c[1][batch_idx])
                    for c in stitch_coords_batch
                ),
            )
            predicted_image[stitch_slices] = cropped_tile.to(torch.float32)

    return predicted_image
careamics/transforms/__init__.py ADDED
@@ -0,0 +1,41 @@
"""Transforms that are used to augment the data."""

__all__ = [
    "get_all_transforms",
    "N2VManipulate",
    "NDFlip",
    "XYRandomRotate90",
    "ImageRestorationTTA",
    "Denormalize",
    "Normalize",
]


from .n2v_manipulate import N2VManipulate
from .nd_flip import NDFlip
from .normalize import Denormalize, Normalize
from .tta import ImageRestorationTTA
from .xy_random_rotate90 import XYRandomRotate90

# Mapping from transform name to transform class. Note that `Denormalize` and
# `ImageRestorationTTA` are exported by this package but are not part of it.
ALL_TRANSFORMS = {
    "Normalize": Normalize,
    "N2VManipulate": N2VManipulate,
    "NDFlip": NDFlip,
    "XYRandomRotate90": XYRandomRotate90,
}


def get_all_transforms() -> dict:
    """Return all the transforms accepted by CAREamics.

    Note that while CAREamics accepts any `Compose` transforms from Albumentations
    (see https://albumentations.ai/), only a few transformations are explicitly
    supported (see `SupportedTransform`).

    Returns
    -------
    dict
        A dictionary with all the transforms accepted by CAREamics, where the keys
        are the transform names and the values are the transform classes. This is
        the module-level `ALL_TRANSFORMS` object itself, so modifying it affects
        every caller.
    """
    return ALL_TRANSFORMS
careamics/transforms/n2v_manipulate.py ADDED
@@ -0,0 +1,113 @@
"""N2V pixel manipulation transform."""
from typing import Any, Literal, Optional, Tuple

import numpy as np
from albumentations import ImageOnlyTransform

from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis

from .pixel_manipulation import median_manipulate, uniform_manipulate
from .struct_mask_parameters import StructMaskParameters


class N2VManipulate(ImageOnlyTransform):
    """
    Default augmentation for the N2V model.

    This transform expects (Z)YXC dimensions. Each channel is manipulated
    independently.

    Parameters
    ----------
    roi_size : int
        Size of the ROI the new pixel value is sampled from, by default 11.
    masked_pixel_percentage : float
        Approximate percentage of pixels to be masked, by default 0.2.
    strategy : Literal["uniform", "median"]
        Pixel replacement strategy, by default "uniform".
    remove_center : bool
        Whether to remove the central pixel from the patch (only forwarded to the
        "uniform" strategy), by default True.
    struct_mask_axis : Literal["horizontal", "vertical", "none"]
        StructN2V mask axis, by default "none" (no structured mask).
    struct_mask_span : int
        StructN2V mask span, by default 5.
    """

    def __init__(
        self,
        roi_size: int = 11,
        masked_pixel_percentage: float = 0.2,
        strategy: Literal[
            "uniform", "median"
        ] = SupportedPixelManipulation.UNIFORM.value,
        remove_center: bool = True,
        struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
        struct_mask_span: int = 5,
    ):
        """Constructor.

        Parameters
        ----------
        roi_size : int, optional
            Size of the replacement area, by default 11
        masked_pixel_percentage : float, optional
            Percentage of pixels to mask, by default 0.2
        strategy : Literal[ "uniform", "median" ], optional
            Replacement strategy, uniform or median, by default uniform
        remove_center : bool, optional
            Whether to remove central pixel from patch, by default True
        struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
            StructN2V mask axis, by default "none"
        struct_mask_span : int, optional
            StructN2V mask span, by default 5
        """
        # Always applied (p=1): the manipulation is required for N2V training.
        super().__init__(p=1)
        self.masked_pixel_percentage = masked_pixel_percentage
        self.roi_size = roi_size
        self.strategy = strategy
        self.remove_center = remove_center
        # Keep the raw StructN2V constructor arguments so that the transform can
        # report them in `get_transform_init_args_names`.
        self.struct_mask_axis = struct_mask_axis
        self.struct_mask_span = struct_mask_span

        if struct_mask_axis == SupportedStructAxis.NONE:
            self.struct_mask: Optional[StructMaskParameters] = None
        else:
            self.struct_mask = StructMaskParameters(
                axis=0 if struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1,
                span=struct_mask_span,
            )

    def apply(
        self, patch: np.ndarray, **kwargs: Any
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        **kwargs : Any
            Additional parameters passed by Albumentations (unused).

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            The manipulated patch, the unmodified input patch, and the mask
            returned by the manipulation function. The first and last arrays have
            the same shape and dtype as `patch`.

        Raises
        ------
        ValueError
            If `strategy` is not a supported pixel manipulation strategy.
        """
        masked = np.zeros_like(patch)
        mask = np.zeros_like(patch)
        if self.strategy == SupportedPixelManipulation.UNIFORM:
            # Iterate over the channels to apply manipulation separately
            for c in range(patch.shape[-1]):
                masked[..., c], mask[..., c] = uniform_manipulate(
                    patch=patch[..., c],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    remove_center=self.remove_center,
                    struct_params=self.struct_mask,
                )
        elif self.strategy == SupportedPixelManipulation.MEDIAN:
            # Iterate over the channels to apply manipulation separately
            for c in range(patch.shape[-1]):
                masked[..., c], mask[..., c] = median_manipulate(
                    patch=patch[..., c],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    struct_params=self.struct_mask,
                )
        else:
            raise ValueError(f"Unknown masking strategy ({self.strategy}).")

        # NOTE(review): the unmodified patch is returned as well, presumably so
        # that callers can use it as the training target - confirm with callers.
        return masked, patch, mask

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        """Get the names of the transform constructor arguments.

        Every returned name is both a constructor parameter and an instance
        attribute, so the reported arguments can be used to rebuild the
        transform.

        Returns
        -------
        Tuple[str, ...]
            Transform constructor argument names.
        """
        return (
            "roi_size",
            "masked_pixel_percentage",
            "strategy",
            "remove_center",
            "struct_mask_axis",
            "struct_mask_span",
        )
careamics/transforms/nd_flip.py ADDED
@@ -0,0 +1,93 @@
"""Random flip transform for ND arrays."""
from typing import Any, Dict, Tuple

import numpy as np
from albumentations import DualTransform


class NDFlip(DualTransform):
    """Flip ND arrays on a single axis.

    Each time the transform is applied, one spatial axis is chosen at random and
    the array is flipped along it:

    - 2D data: Y or X;
    - 3D data: Z, Y or X (Z is excluded when `flip_z` is False).

    The last axis (channels) is never flipped.

    This transform expects (Z)YXC dimensions.
    """

    def __init__(self, p: float = 0.5, is_3D: bool = False, flip_z: bool = True):
        """Constructor.

        Parameters
        ----------
        p : float, optional
            Probability to apply the transform, by default 0.5
        is_3D : bool, optional
            Whether the data is 3D, by default False
        flip_z : bool, optional
            Whether to flip Z dimension, by default True
        """
        super().__init__(p=p)

        self.is_3D = is_3D
        self.flip_z = flip_z

        # "Flippable" axes, as indices into (Z)YXC arrays
        if is_3D:
            self.axis_indices = [0, 1, 2] if flip_z else [1, 2]
        else:
            self.axis_indices = [0, 1]

    def get_params(self, **kwargs: Any) -> Dict[str, int]:
        """Get the transform parameters.

        Returns
        -------
        Dict[str, int]
            Transform parameters, with `flip_axis` drawn uniformly at random from
            the flippable axes.
        """
        return {"flip_axis": np.random.choice(self.axis_indices)}

    def apply(self, patch: np.ndarray, flip_axis: int, **kwargs: Any) -> np.ndarray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        flip_axis : int
            Axis along which to flip the patch.
        **kwargs : Any
            Additional parameters passed by Albumentations (unused).

        Returns
        -------
        np.ndarray
            Flipped patch, as a C-contiguous array.

        Raises
        ------
        ValueError
            If the transform is configured for 3D data but the patch has 3 dims.
        """
        if len(patch.shape) == 3 and self.is_3D:
            raise ValueError(
                "Incompatible patch shape and dimensionality. ZYXC patch shape "
                "expected, but got YXC shape."
            )

        # `np.flip` returns a view with negative strides; make it contiguous
        return np.ascontiguousarray(np.flip(patch, axis=flip_axis))

    def apply_to_mask(
        self, mask: np.ndarray, flip_axis: int, **kwargs: Any
    ) -> np.ndarray:
        """Apply the transform to the mask.

        The mask is flipped along the same axis as the image.

        Parameters
        ----------
        mask : np.ndarray
            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        flip_axis : int
            Axis along which to flip the mask.
        **kwargs : Any
            Additional parameters passed by Albumentations (unused).

        Returns
        -------
        np.ndarray
            Flipped mask, as a C-contiguous array.

        Raises
        ------
        ValueError
            If the transform is configured for 3D data but the mask has 3 dims.
        """
        if len(mask.shape) == 3 and self.is_3D:
            raise ValueError(
                "Incompatible mask shape and dimensionality. ZYXC patch shape "
                "expected, but got YXC shape."
            )

        return np.ascontiguousarray(np.flip(mask, axis=flip_axis))

    def get_transform_init_args_names(self, **kwargs: Any) -> Tuple[str, ...]:
        """Get the transform arguments names.

        Returns
        -------
        Tuple[str, ...]
            Transform arguments names.
        """
        return ("is_3D", "flip_z")
careamics/transforms/normalize.py ADDED
@@ -0,0 +1,109 @@
"""Normalization and denormalization transforms."""
from typing import Any

import numpy as np
from albumentations import DualTransform


class Normalize(DualTransform):
    """
    Normalize an image or image patch.

    The image is shifted by `mean` and scaled by `std`, which yields zero mean and
    unit variance when these are the statistics of the data. This transform
    expects (Z)YXC dimensions.

    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
    division by zero and that it returns a float32 image.

    Attributes
    ----------
    mean : float
        Mean value.
    std : float
        Standard deviation value.
    eps : float
        Epsilon value to avoid division by zero.
    """

    def __init__(
        self,
        mean: float,
        std: float,
    ):
        """Constructor.

        Parameters
        ----------
        mean : float
            Mean value.
        std : float
            Standard deviation value.
        """
        super().__init__(always_apply=True, p=1)

        self.mean = mean
        self.std = std
        self.eps = 1e-6

    def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray:
        """
        Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        **kwargs : Any
            Additional parameters passed by Albumentations (unused).

        Returns
        -------
        np.ndarray
            Normalized image or image patch, as float32.
        """
        return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)

    def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray:
        """
        Apply the transform to the mask.

        The mask is returned as is.

        Parameters
        ----------
        mask : np.ndarray
            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        **kwargs : Any
            Additional parameters passed by Albumentations (unused).

        Returns
        -------
        np.ndarray
            The unmodified mask.
        """
        return mask


class Denormalize(DualTransform):
    """
    Denormalize an image or image patch.

    Denormalization is performed expecting a zero mean and unit variance input. This
    transform expects (Z)YXC dimensions.

    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
    division by zero during the normalization step, which is taken into account during
    denormalization.

    Attributes
    ----------
    mean : float
        Mean value.
    std : float
        Standard deviation value.
    eps : float
        Epsilon value to avoid division by zero.
    """

    # NOTE(review): unlike `Normalize`, no `apply_to_mask` override is defined,
    # so masks follow the `DualTransform` default behaviour - confirm whether
    # masks are meant to be denormalized as well.

    def __init__(
        self,
        mean: float,
        std: float,
    ):
        """Constructor.

        Parameters
        ----------
        mean : float
            Mean value used during normalization.
        std : float
            Standard deviation value used during normalization.
        """
        super().__init__(always_apply=True, p=1)

        self.mean = mean
        self.std = std
        self.eps = 1e-6

    def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray:
        """
        Apply the transform to the image.

        This is the inverse of `Normalize.apply` (up to the float32 cast performed
        there).

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        **kwargs : Any
            Additional parameters passed by Albumentations (unused).

        Returns
        -------
        np.ndarray
            Denormalized image or image patch.
        """
        return patch * (self.std + self.eps) + self.mean