careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Files changed (62)
  1. careamics/careamist.py +12 -11
  2. careamics/config/__init__.py +0 -1
  3. careamics/config/architectures/unet_model.py +1 -0
  4. careamics/config/callback_model.py +1 -0
  5. careamics/config/configuration_example.py +0 -2
  6. careamics/config/configuration_factory.py +112 -42
  7. careamics/config/configuration_model.py +14 -16
  8. careamics/config/data_model.py +59 -157
  9. careamics/config/inference_model.py +19 -20
  10. careamics/config/references/algorithm_descriptions.py +1 -0
  11. careamics/config/references/references.py +1 -0
  12. careamics/config/support/supported_extraction_strategies.py +1 -0
  13. careamics/config/training_model.py +1 -0
  14. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  15. careamics/config/transformations/nd_flip_model.py +6 -11
  16. careamics/config/transformations/normalize_model.py +1 -0
  17. careamics/config/transformations/transform_model.py +1 -0
  18. careamics/config/transformations/xy_random_rotate90_model.py +6 -8
  19. careamics/config/validators/validator_utils.py +1 -0
  20. careamics/conftest.py +1 -0
  21. careamics/dataset/dataset_utils/__init__.py +0 -1
  22. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  23. careamics/dataset/in_memory_dataset.py +14 -45
  24. careamics/dataset/iterable_dataset.py +13 -68
  25. careamics/dataset/patching/__init__.py +0 -7
  26. careamics/dataset/patching/patching.py +1 -0
  27. careamics/dataset/patching/sequential_patching.py +6 -6
  28. careamics/dataset/patching/tiled_patching.py +10 -6
  29. careamics/lightning_datamodule.py +20 -24
  30. careamics/lightning_module.py +1 -1
  31. careamics/lightning_prediction_datamodule.py +15 -10
  32. careamics/losses/__init__.py +0 -1
  33. careamics/losses/loss_factory.py +1 -0
  34. careamics/model_io/__init__.py +0 -1
  35. careamics/model_io/bioimage/_readme_factory.py +2 -1
  36. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  37. careamics/model_io/bioimage/model_description.py +1 -0
  38. careamics/model_io/bmz_io.py +2 -1
  39. careamics/models/layers.py +1 -0
  40. careamics/models/model_factory.py +1 -0
  41. careamics/models/unet.py +91 -17
  42. careamics/prediction/stitch_prediction.py +1 -0
  43. careamics/transforms/__init__.py +2 -23
  44. careamics/transforms/compose.py +98 -0
  45. careamics/transforms/n2v_manipulate.py +18 -23
  46. careamics/transforms/nd_flip.py +38 -64
  47. careamics/transforms/normalize.py +45 -34
  48. careamics/transforms/pixel_manipulation.py +2 -2
  49. careamics/transforms/transform.py +33 -0
  50. careamics/transforms/tta.py +2 -2
  51. careamics/transforms/xy_random_rotate90.py +41 -68
  52. careamics/utils/__init__.py +0 -1
  53. careamics/utils/context.py +1 -0
  54. careamics/utils/logging.py +1 -0
  55. careamics/utils/metrics.py +1 -0
  56. careamics/utils/torch_utils.py +1 -0
  57. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  58. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  59. careamics/dataset/patching/patch_transform.py +0 -44
  60. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  61. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  62. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py CHANGED
@@ -3,7 +3,8 @@ UNet model.
 
 A UNet encoder, decoder and complete model.
 """
-from typing import Any, List, Union
+
+from typing import Any, List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -33,6 +34,9 @@ class UnetEncoder(nn.Module):
         Dropout probability, by default 0.0.
     pool_kernel : int, optional
         Kernel size for the max pooling layers, by default 2.
+    groups: int, optional
+        Number of blocked connections from input channels to output
+        channels, by default 1.
     """
 
     def __init__(
@@ -45,6 +49,7 @@ class UnetEncoder(nn.Module):
         dropout: float = 0.0,
         pool_kernel: int = 2,
         n2v2: bool = False,
+        groups: int = 1,
     ) -> None:
         """
         Constructor.
@@ -65,6 +70,9 @@ class UnetEncoder(nn.Module):
             Dropout probability, by default 0.0.
         pool_kernel : int, optional
             Kernel size for the max pooling layers, by default 2.
+        groups: int, optional
+            Number of blocked connections from input channels to output
+            channels, by default 1.
         """
         super().__init__()
 
@@ -77,7 +85,7 @@ class UnetEncoder(nn.Module):
         encoder_blocks = []
 
         for n in range(depth):
-            out_channels = num_channels_init * (2**n)
+            out_channels = num_channels_init * (2**n) * groups
             in_channels = in_channels if n == 0 else out_channels // 2
             encoder_blocks.append(
                 Conv_Block(
@@ -86,6 +94,7 @@ class UnetEncoder(nn.Module):
                     out_channels=out_channels,
                     dropout_perc=dropout,
                     use_batch_norm=use_batch_norm,
+                    groups=groups,
                 )
             )
             encoder_blocks.append(self.pooling)
@@ -131,6 +140,9 @@ class UnetDecoder(nn.Module):
         Whether to use batch normalization, by default True.
     dropout : float, optional
         Dropout probability, by default 0.0.
+    groups: int, optional
+        Number of blocked connections from input channels to output
+        channels, by default 1.
     """
 
     def __init__(
@@ -141,6 +153,7 @@ class UnetDecoder(nn.Module):
         use_batch_norm: bool = True,
         dropout: float = 0.0,
         n2v2: bool = False,
+        groups: int = 1,
     ) -> None:
         """
         Constructor.
@@ -157,15 +170,19 @@ class UnetDecoder(nn.Module):
             Whether to use batch normalization, by default True.
         dropout : float, optional
             Dropout probability, by default 0.0.
+        groups: int, optional
+            Number of blocked connections from input channels to output
+            channels, by default 1.
         """
         super().__init__()
 
         upsampling = nn.Upsample(
             scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
         )
-        in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
+        in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
 
         self.n2v2 = n2v2
+        self.groups = groups
 
         self.bottleneck = Conv_Block(
             conv_dim,
@@ -174,34 +191,32 @@ class UnetDecoder(nn.Module):
             intermediate_channel_multiplier=2,
             use_batch_norm=use_batch_norm,
             dropout_perc=dropout,
+            groups=self.groups,
         )
 
-        decoder_blocks = []
+        decoder_blocks: List[nn.Module] = []
        for n in range(depth):
             decoder_blocks.append(upsampling)
-            in_channels = (
-                num_channels_init ** (depth - n)
-                if (self.n2v2 and n == depth - 1)
-                else num_channels_init * 2 ** (depth - n)
-            )
+            in_channels = (num_channels_init * 2 ** (depth - n)) * groups
             out_channels = in_channels // 2
             decoder_blocks.append(
                 Conv_Block(
                     conv_dim,
-                    in_channels=in_channels + in_channels // 2
-                    if n > 0
-                    else in_channels,
+                    in_channels=(
+                        in_channels + in_channels // 2 if n > 0 else in_channels
+                    ),
                     out_channels=out_channels,
                     intermediate_channel_multiplier=2,
                     dropout_perc=dropout,
                     activation="ReLU",
                     use_batch_norm=use_batch_norm,
+                    groups=groups,
                 )
             )
 
         self.decoder_blocks = nn.ModuleList(decoder_blocks)
 
-    def forward(self, *features: List[torch.Tensor]) -> torch.Tensor:
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
 
@@ -217,20 +232,70 @@ class UnetDecoder(nn.Module):
             Output of the decoder.
         """
         x: torch.Tensor = features[0]
-        skip_connections: torch.Tensor = features[1:][::-1]
+        skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
 
         x = self.bottleneck(x)
 
         for i, module in enumerate(self.decoder_blocks):
             x = module(x)
             if isinstance(module, nn.Upsample):
+                # divide index by 2 because of upsampling layers
+                skip_connection: torch.Tensor = skip_connections[i // 2]
                 if self.n2v2:
                     if x.shape != skip_connections[-1].shape:
-                        x = torch.cat([x, skip_connections[i // 2]], axis=1)
+                        x = self._interleave(x, skip_connection, self.groups)
                 else:
-                    x = torch.cat([x, skip_connections[i // 2]], axis=1)
+                    x = self._interleave(x, skip_connection, self.groups)
         return x
 
+    @staticmethod
+    def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
+        """
+        Splits the tensors `A` and `B` into equally sized groups along the
+        channel axis (axis=1); then concatenates the groups in alternating
+        order along the channel axis, starting with the first group from tensor
+        A.
+
+        Parameters
+        ----------
+        A: torch.Tensor
+        B: torch.Tensor
+        groups: int
+            The number of groups.
+
+        Returns
+        -------
+        torch.Tensor
+
+        Raises
+        ------
+        ValueError:
+            If either of `A` or `B`'s channel axis is not divisible by `groups`.
+        """
+        if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
+            raise ValueError(f"Number of channels not divisible by {groups} groups.")
+
+        m = A.shape[1] // groups
+        n = B.shape[1] // groups
+
+        A_groups: List[torch.Tensor] = [
+            A[:, i * m : (i + 1) * m] for i in range(groups)
+        ]
+        B_groups: List[torch.Tensor] = [
+            B[:, i * n : (i + 1) * n] for i in range(groups)
+        ]
+
+        interleaved = torch.cat(
+            [
+                tensor_list[i]
+                for i in range(groups)
+                for tensor_list in [A_groups, B_groups]
+            ],
+            dim=1,
+        )
+
+        return interleaved
+
 
 class UNet(nn.Module):
     """
@@ -273,6 +338,7 @@ class UNet(nn.Module):
         pool_kernel: int = 2,
         final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
         n2v2: bool = False,
+        independent_channels: bool = True,
         **kwargs: Any,
     ) -> None:
         """
@@ -298,9 +364,14 @@ class UNet(nn.Module):
             Kernel size of the pooling layers, by default 2.
         last_activation : Optional[Callable], optional
             Activation function to use for the last layer, by default None.
+        independent_channels : bool
+            Whether to train parallel independent networks for each channel, by
+            default True.
         """
         super().__init__()
 
+        groups = in_channels if independent_channels else 1
+
         self.encoder = UnetEncoder(
             conv_dims,
             in_channels=in_channels,
@@ -310,6 +381,7 @@ class UNet(nn.Module):
             dropout=dropout,
             pool_kernel=pool_kernel,
             n2v2=n2v2,
+            groups=groups,
         )
 
         self.decoder = UnetDecoder(
@@ -319,11 +391,13 @@ class UNet(nn.Module):
             use_batch_norm=use_batch_norm,
             dropout=dropout,
             n2v2=n2v2,
+            groups=groups,
        )
         self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
-            in_channels=num_channels_init,
+            in_channels=num_channels_init * groups,
             out_channels=num_classes,
             kernel_size=1,
+            groups=groups,
         )
         self.final_activation = get_activation(final_activation)
 
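The changes above thread a `groups` argument through the encoder, decoder, bottleneck and final 1x1 convolution, with `groups = in_channels` when `independent_channels=True`; since grouped convolutions do not mix channels across groups, each input channel is then processed by its own parallel sub-network. The new `_interleave` helper preserves that separation at the skip connections: instead of a plain `torch.cat`, decoder and encoder features are concatenated group by group, presumably so that each group of the following grouped convolution only sees features belonging to its own channel. The snippet below is a standalone sketch of that interleaving and is not part of the package; the `interleave` helper and the toy shapes are illustrative only.

import torch


def interleave(a: torch.Tensor, b: torch.Tensor, groups: int) -> torch.Tensor:
    # Split both tensors into `groups` equal chunks along the channel axis,
    # then alternate a-chunk, b-chunk, a-chunk, b-chunk, ... (the same ordering
    # produced by UnetDecoder._interleave above).
    a_chunks = torch.chunk(a, groups, dim=1)
    b_chunks = torch.chunk(b, groups, dim=1)
    return torch.cat([t for pair in zip(a_chunks, b_chunks) for t in pair], dim=1)


# decoder features for 2 channel-groups: channels [A0, A0, A1, A1]
x = torch.zeros(1, 4, 8, 8)
# matching skip connection: channels [B0, B0, B1, B1]
skip = torch.ones(1, 4, 8, 8)

out = interleave(x, skip, groups=2)
print(out.shape)  # torch.Size([1, 8, 8, 8]); channel order A0 A0 B0 B0 A1 A1 B1 B1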
careamics/prediction/stitch_prediction.py CHANGED
@@ -3,6 +3,7 @@ Prediction convenience functions.
 
 These functions are used during prediction.
 """
+
 from typing import List
 
 import numpy as np
careamics/transforms/__init__.py CHANGED
@@ -8,34 +8,13 @@ __all__ = [
     "ImageRestorationTTA",
     "Denormalize",
     "Normalize",
+    "Compose",
 ]
 
 
+from .compose import Compose, get_all_transforms
 from .n2v_manipulate import N2VManipulate
 from .nd_flip import NDFlip
 from .normalize import Denormalize, Normalize
 from .tta import ImageRestorationTTA
 from .xy_random_rotate90 import XYRandomRotate90
-
-ALL_TRANSFORMS = {
-    "Normalize": Normalize,
-    "N2VManipulate": N2VManipulate,
-    "NDFlip": NDFlip,
-    "XYRandomRotate90": XYRandomRotate90,
-}
-
-
-def get_all_transforms() -> dict:
-    """Return all the transforms accepted by CAREamics.
-
-    Note that while CAREamics accepts any `Compose` transforms from Albumentations (see
-    https://albumentations.ai/), only a few transformations are explicitely supported
-    (see `SupportedTransform`).
-
-    Returns
-    -------
-    dict
-        A dictionary with all the transforms accepted by CAREamics, where the keys are
-        the transform names and the values are the transform classes.
-    """
-    return ALL_TRANSFORMS
careamics/transforms/compose.py ADDED
@@ -0,0 +1,98 @@
+"""A class chaining transforms together."""
+
+from typing import Callable, List, Optional, Tuple
+
+import numpy as np
+
+from careamics.config.data_model import TRANSFORMS_UNION
+
+from .n2v_manipulate import N2VManipulate
+from .nd_flip import NDFlip
+from .normalize import Normalize
+from .transform import Transform
+from .xy_random_rotate90 import XYRandomRotate90
+
+ALL_TRANSFORMS = {
+    "Normalize": Normalize,
+    "N2VManipulate": N2VManipulate,
+    "NDFlip": NDFlip,
+    "XYRandomRotate90": XYRandomRotate90,
+}
+
+
+def get_all_transforms() -> dict:
+    """Return all the transforms accepted by CAREamics.
+
+    Returns
+    -------
+    dict
+        A dictionary with all the transforms accepted by CAREamics, where the keys are
+        the transform names and the values are the transform classes.
+    """
+    return ALL_TRANSFORMS
+
+
+class Compose:
+    """A class chaining transforms together."""
+
+    def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
+        """Instantiate a Compose object.
+
+        Parameters
+        ----------
+        transform_list : List[TRANSFORMS_UNION]
+            A list of dictionaries where each dictionary contains the name of a
+            transform and its parameters.
+        """
+        # retrieve all available transforms
+        all_transforms = get_all_transforms()
+
+        # instantiate all transforms
+        transforms = [all_transforms[t.name](**t.model_dump()) for t in transform_list]
+
+        self._callable_transforms = self._chain_transforms(transforms)
+
+    def _chain_transforms(self, transforms: List[Transform]) -> Callable:
+        """Chain the transforms together.
+
+        Parameters
+        ----------
+        transforms : List[Transform]
+            A list of transforms to chain together.
+
+        Returns
+        -------
+        Callable
+            A callable that applies the transforms in order to the input data.
+        """
+
+        def _chain(
+            patch: np.ndarray, target: Optional[np.ndarray]
+        ) -> Tuple[np.ndarray, ...]:
+            params = (patch, target)
+
+            for t in transforms:
+                params = t(*params)
+
+            return params
+
+        return _chain
+
+    def __call__(
+        self, patch: np.ndarray, target: Optional[np.ndarray] = None
+    ) -> Tuple[np.ndarray, ...]:
+        """Apply the transforms to the input data.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            The input data.
+        target : Optional[np.ndarray], optional
+            Target data, by default None
+
+        Returns
+        -------
+        Tuple[np.ndarray, ...]
+            The output of the transformations.
+        """
+        return self._callable_transforms(patch, target)
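The new `Compose` class threads a `(patch, target)` tuple through each transform in turn: every transform returns the tuple that is unpacked into the next one, and the target is simply carried along as `None` when absent. The snippet below is a self-contained sketch of that threading with hypothetical stand-in transforms (`AddOne` and `Square` are not part of CAREamics); in the package itself the transforms are instantiated from the pydantic transform models registered in `ALL_TRANSFORMS`.

from typing import List, Optional, Tuple

import numpy as np


class AddOne:
    """Stand-in transform mirroring the Transform call signature."""

    def __call__(self, patch: np.ndarray, target: Optional[np.ndarray] = None):
        return patch + 1, target


class Square:
    """Stand-in transform mirroring the Transform call signature."""

    def __call__(self, patch: np.ndarray, target: Optional[np.ndarray] = None):
        return patch**2, None if target is None else target**2


def chain(transforms: List, patch: np.ndarray, target: Optional[np.ndarray] = None):
    # same pattern as Compose._chain_transforms: each transform returns the
    # (patch, target) tuple that is fed into the next transform
    params: Tuple[np.ndarray, Optional[np.ndarray]] = (patch, target)
    for t in transforms:
        params = t(*params)
    return params


patch = np.full((1, 4, 4), 2.0)  # C(Z)YX patch
out, tgt = chain([AddOne(), Square()], patch)
print(out[0, 0, 0], tgt)  # 9.0 None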
careamics/transforms/n2v_manipulate.py CHANGED
@@ -1,19 +1,19 @@
 from typing import Any, Literal, Optional, Tuple
 
 import numpy as np
-from albumentations import ImageOnlyTransform
 
 from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
+from careamics.transforms.transform import Transform
 
 from .pixel_manipulation import median_manipulate, uniform_manipulate
 from .struct_mask_parameters import StructMaskParameters
 
 
-class N2VManipulate(ImageOnlyTransform):
+class N2VManipulate(Transform):
     """
     Default augmentation for the N2V model.
 
-    This transform expects (Z)YXC dimensions.
+    This transform expects C(Z)YX dimensions.
 
     Parameters
     ----------
@@ -33,6 +33,7 @@ class N2VManipulate(ImageOnlyTransform):
         remove_center: bool = True,
         struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
         struct_mask_span: int = 5,
+        seed: Optional[int] = None,  # TODO use in pixel manipulation
     ):
         """Constructor.
 
@@ -50,8 +51,9 @@ class N2VManipulate(ImageOnlyTransform):
             StructN2V mask axis, by default "none"
         struct_mask_span : int, optional
            StructN2V mask span, by default 5
+        seed : Optional[int], optional
+            Random seed, by default None
         """
-        super().__init__(p=1)
         self.masked_pixel_percentage = masked_pixel_percentage
         self.roi_size = roi_size
         self.strategy = strategy
@@ -65,23 +67,26 @@ class N2VManipulate(ImageOnlyTransform):
             span=struct_mask_span,
         )
 
-    def apply(
-        self, patch: np.ndarray, **kwargs: Any
+        # numpy random generator
+        self.rng = np.random.default_rng(seed=seed)
+
+    def __call__(
+        self, patch: np.ndarray, *args: Any, **kwargs: Any
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Apply the transform to the image.
 
         Parameters
         ----------
         image : np.ndarray
-            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
+            Image or image patch, 2D or 3D, shape C(Z)YX.
         """
         masked = np.zeros_like(patch)
         mask = np.zeros_like(patch)
         if self.strategy == SupportedPixelManipulation.UNIFORM:
             # Iterate over the channels to apply manipulation separately
-            for c in range(patch.shape[-1]):
-                masked[..., c], mask[..., c] = uniform_manipulate(
-                    patch=patch[..., c],
+            for c in range(patch.shape[0]):
+                masked[c, ...], mask[c, ...] = uniform_manipulate(
+                    patch=patch[c, ...],
                     mask_pixel_percentage=self.masked_pixel_percentage,
                     subpatch_size=self.roi_size,
                     remove_center=self.remove_center,
@@ -89,9 +94,9 @@ class N2VManipulate(ImageOnlyTransform):
                 )
         elif self.strategy == SupportedPixelManipulation.MEDIAN:
             # Iterate over the channels to apply manipulation separately
-            for c in range(patch.shape[-1]):
-                masked[..., c], mask[..., c] = median_manipulate(
-                    patch=patch[..., c],
+            for c in range(patch.shape[0]):
+                masked[c, ...], mask[c, ...] = median_manipulate(
+                    patch=patch[c, ...],
                     mask_pixel_percentage=self.masked_pixel_percentage,
                     subpatch_size=self.roi_size,
                     struct_params=self.struct_mask,
@@ -101,13 +106,3 @@ class N2VManipulate(ImageOnlyTransform):
 
         # TODO why return patch?
         return masked, patch, mask
-
-    def get_transform_init_args_names(self) -> Tuple[str, ...]:
-        """Get the transform parameters.
-
-        Returns
-        -------
-        Tuple[str, ...]
-            Transform parameters.
-        """
-        return ("roi_size", "masked_pixel_percentage", "strategy", "struct_mask")
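A short usage sketch based on the new `__call__` signature above; the constructor arguments not visible in these hunks are assumed to have defaults, and `seed` is stored but not yet forwarded to the pixel manipulation, per the TODO in the diff.

import numpy as np

from careamics.transforms.n2v_manipulate import N2VManipulate

manipulate = N2VManipulate(seed=42)         # seed stored but not yet used (TODO above)
patch = np.random.rand(1, 64, 64)           # C(Z)YX patch with a single channel
masked, original, mask = manipulate(patch)  # masking applied per channel along axis 0
assert masked.shape == original.shape == mask.shape == patch.shape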
careamics/transforms/nd_flip.py CHANGED
@@ -1,93 +1,67 @@
-from typing import Any, Dict, Tuple
+from typing import Optional, Tuple
 
 import numpy as np
-from albumentations import DualTransform
 
+from careamics.transforms.transform import Transform
 
-class NDFlip(DualTransform):
+
+class NDFlip(Transform):
     """Flip ND arrays on a single axis.
 
     This transform ignores singleton axes and randomly flips one of the other
-    axes, to the exception of the first and last axes (sample and channels).
+    last two axes.
 
-    This transform expects (Z)YXC dimensions.
+    This transform expects C(Z)YX dimensions.
     """
 
-    def __init__(self, p: float = 0.5, is_3D: bool = False, flip_z: bool = True):
+    def __init__(self, seed: Optional[int] = None):
         """Constructor.
 
         Parameters
         ----------
-        p : float, optional
-            Probability to apply the transform, by default 0.5
-        is_3D : bool, optional
-            Whether the data is 3D, by default False
-        flip_z : bool, optional
-            Whether to flip Z dimension, by default True
+        seed : Optional[int], optional
+            Random seed, by default None
         """
-        super().__init__(p=p)
-
-        self.is_3D = is_3D
-        self.flip_z = flip_z
-
         # "flippable" axes
-        if is_3D:
-            self.axis_indices = [0, 1, 2] if flip_z else [1, 2]
-        else:
-            self.axis_indices = [0, 1]
+        self.axis_indices = [-2, -1]
 
-    def get_params(self, **kwargs: Any) -> Dict[str, int]:
-        """Get the transform parameters.
-
-        Returns
-        -------
-        Dict[str, int]
-            Transform parameters.
-        """
-        return {"flip_axis": np.random.choice(self.axis_indices)}
+        # numpy random generator
+        self.rng = np.random.default_rng(seed=seed)
 
-    def apply(self, patch: np.ndarray, flip_axis: int, **kwargs: Any) -> np.ndarray:
-        """Apply the transform to the image.
+    def __call__(
+        self, patch: np.ndarray, target: Optional[np.ndarray] = None
+    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        """Apply the transform to the source patch and the target (optional).
 
         Parameters
         ----------
         patch : np.ndarray
-            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
-        flip_axis : int
-            Axis along which to flip the patch.
+            Patch, 2D or 3D, shape C(Z)YX.
+        target : Optional[np.ndarray], optional
+            Target for the patch, by default None
+
+        Returns
+        -------
+        Tuple[np.ndarray, Optional[np.ndarray]]
+            Transformed patch and target.
         """
-        if len(patch.shape) == 3 and self.is_3D:
-            raise ValueError(
-                "Incompatible patch shape and dimensionality. ZYXC patch shape "
-                "expected, but got YXC shape."
-            )
+        # choose an axis to flip
+        axis = self.rng.choice(self.axis_indices)
+
+        patch_transformed = self._apply(patch, axis)
+        target_transformed = self._apply(target, axis) if target is not None else None
 
-        return np.ascontiguousarray(np.flip(patch, axis=flip_axis))
+        return patch_transformed, target_transformed
 
-    def apply_to_mask(
-        self, mask: np.ndarray, flip_axis: int, **kwargs: Any
-    ) -> np.ndarray:
-        """Apply the transform to the mask.
+    def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
+        """Apply the transform to the image.
 
         Parameters
         ----------
-        mask : np.ndarray
-            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
-        """
-        if len(mask.shape) == 3 and self.is_3D:
-            raise ValueError(
-                "Incompatible mask shape and dimensionality. ZYXC patch shape "
-                "expected, but got YXC shape."
-            )
-
-        return np.ascontiguousarray(np.flip(mask, axis=flip_axis))
-
-    def get_transform_init_args_names(self, **kwargs: Any) -> Tuple[str, ...]:
-        """Get the transform arguments names.
-
-        Returns
-        -------
-        Tuple[str, ...]
-            Transform arguments names.
+        patch : np.ndarray
+            Image or image patch, 2D or 3D, shape C(Z)YX.
+        axis : int
+            Axis to flip.
         """
-        return ("is_3D", "flip_z")
+        # TODO why ascontiguousarray?
+        return np.ascontiguousarray(np.flip(patch, axis=axis))
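A short usage sketch based on the signatures above (import path taken from the module shown in this diff; the shapes are illustrative): `NDFlip` now draws the flip axis from a seeded numpy `Generator` and applies the same flip to the patch and, when given, to the target.

import numpy as np

from careamics.transforms.nd_flip import NDFlip

flip = NDFlip(seed=42)                    # seeded numpy Generator
patch = np.arange(2 * 4 * 4, dtype=np.float32).reshape(2, 4, 4)  # CYX

flipped, no_target = flip(patch)          # no target supplied -> returned as None
assert flipped.shape == patch.shape and no_target is None

# the same axis is flipped on the patch and the target within a single call
flipped2, target2 = flip(patch, patch.copy())
assert np.array_equal(flipped2, target2)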