careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's release page for details.

Files changed (81)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +80 -44
  4. careamics/config/algorithm_model.py +5 -3
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +8 -1
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -2
  12. careamics/config/configuration_factory.py +4 -16
  13. careamics/config/data_model.py +10 -14
  14. careamics/config/inference_model.py +0 -65
  15. careamics/config/optimizer_models.py +4 -4
  16. careamics/config/support/__init__.py +0 -2
  17. careamics/config/support/supported_activations.py +2 -0
  18. careamics/config/support/supported_algorithms.py +3 -1
  19. careamics/config/support/supported_architectures.py +2 -0
  20. careamics/config/support/supported_data.py +2 -0
  21. careamics/config/support/supported_loggers.py +2 -0
  22. careamics/config/support/supported_losses.py +2 -0
  23. careamics/config/support/supported_optimizers.py +2 -0
  24. careamics/config/support/supported_pixel_manipulations.py +3 -3
  25. careamics/config/support/supported_struct_axis.py +2 -0
  26. careamics/config/support/supported_transforms.py +4 -15
  27. careamics/config/tile_information.py +2 -0
  28. careamics/config/transformations/__init__.py +3 -2
  29. careamics/config/transformations/xy_flip_model.py +43 -0
  30. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  31. careamics/conftest.py +12 -0
  32. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  33. careamics/dataset/dataset_utils/file_utils.py +4 -3
  34. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  35. careamics/dataset/dataset_utils/read_utils.py +2 -0
  36. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  37. careamics/dataset/in_memory_dataset.py +71 -32
  38. careamics/dataset/iterable_dataset.py +155 -68
  39. careamics/dataset/patching/patching.py +56 -15
  40. careamics/dataset/patching/random_patching.py +8 -2
  41. careamics/dataset/patching/sequential_patching.py +14 -8
  42. careamics/dataset/patching/tiled_patching.py +3 -1
  43. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  44. careamics/dataset/zarr_dataset.py +2 -0
  45. careamics/lightning_datamodule.py +45 -19
  46. careamics/lightning_module.py +8 -2
  47. careamics/lightning_prediction_datamodule.py +3 -13
  48. careamics/lightning_prediction_loop.py +8 -6
  49. careamics/losses/__init__.py +2 -3
  50. careamics/losses/loss_factory.py +1 -1
  51. careamics/losses/losses.py +11 -7
  52. careamics/model_io/bmz_io.py +3 -3
  53. careamics/models/activation.py +2 -0
  54. careamics/models/layers.py +121 -25
  55. careamics/models/model_factory.py +1 -1
  56. careamics/models/unet.py +35 -14
  57. careamics/prediction/stitch_prediction.py +2 -6
  58. careamics/transforms/__init__.py +2 -2
  59. careamics/transforms/compose.py +33 -7
  60. careamics/transforms/n2v_manipulate.py +49 -13
  61. careamics/transforms/normalize.py +55 -3
  62. careamics/transforms/pixel_manipulation.py +5 -5
  63. careamics/transforms/struct_mask_parameters.py +3 -1
  64. careamics/transforms/transform.py +10 -19
  65. careamics/transforms/xy_flip.py +123 -0
  66. careamics/transforms/xy_random_rotate90.py +38 -5
  67. careamics/utils/base_enum.py +28 -0
  68. careamics/utils/path_utils.py +2 -0
  69. careamics/utils/ram.py +2 -0
  70. careamics/utils/receptive_field.py +93 -87
  71. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
  72. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  73. careamics/config/noise_models.py +0 -162
  74. careamics/config/support/supported_extraction_strategies.py +0 -25
  75. careamics/config/transformations/nd_flip_model.py +0 -27
  76. careamics/losses/noise_model_factory.py +0 -40
  77. careamics/losses/noise_models.py +0 -524
  78. careamics/transforms/nd_flip.py +0 -67
  79. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  80. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  81. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py CHANGED
@@ -34,7 +34,9 @@ class UnetEncoder(nn.Module):
34
34
  Dropout probability, by default 0.0.
35
35
  pool_kernel : int, optional
36
36
  Kernel size for the max pooling layers, by default 2.
37
- groups: int, optional
37
+ n2v2 : bool, optional
38
+ Whether to use N2V2 architecture, by default False.
39
+ groups : int, optional
38
40
  Number of blocked connections from input channels to output
39
41
  channels, by default 1.
40
42
  """
@@ -70,7 +72,9 @@ class UnetEncoder(nn.Module):
70
72
  Dropout probability, by default 0.0.
71
73
  pool_kernel : int, optional
72
74
  Kernel size for the max pooling layers, by default 2.
73
- groups: int, optional
75
+ n2v2 : bool, optional
76
+ Whether to use N2V2 architecture, by default False.
77
+ groups : int, optional
74
78
  Number of blocked connections from input channels to output
75
79
  channels, by default 1.
76
80
  """
@@ -140,7 +144,9 @@ class UnetDecoder(nn.Module):
140
144
  Whether to use batch normalization, by default True.
141
145
  dropout : float, optional
142
146
  Dropout probability, by default 0.0.
143
- groups: int, optional
147
+ n2v2 : bool, optional
148
+ Whether to use N2V2 architecture, by default False.
149
+ groups : int, optional
144
150
  Number of blocked connections from input channels to output
145
151
  channels, by default 1.
146
152
  """
@@ -170,7 +176,9 @@ class UnetDecoder(nn.Module):
170
176
  Whether to use batch normalization, by default True.
171
177
  dropout : float, optional
172
178
  Dropout probability, by default 0.0.
173
- groups: int, optional
179
+ n2v2 : bool, optional
180
+ Whether to use N2V2 architecture, by default False.
181
+ groups : int, optional
174
182
  Number of blocked connections from input channels to output
175
183
  channels, by default 1.
176
184
  """
@@ -250,22 +258,25 @@ class UnetDecoder(nn.Module):
250
258
 
251
259
  @staticmethod
252
260
  def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
253
- """
254
- Splits the tensors `A` and `B` into equally sized groups along the
255
- channel axis (axis=1); then concatenates the groups in alternating
256
- order along the channel axis, starting with the first group from tensor
257
- A.
261
+ """Interleave two tensors.
262
+
263
+ Splits the tensors `A` and `B` into equally sized groups along the channel
264
+ axis (axis=1); then concatenates the groups in alternating order along the
265
+ channel axis, starting with the first group from tensor A.
258
266
 
259
267
  Parameters
260
268
  ----------
261
- A: torch.Tensor
262
- B: torch.Tensor
263
- groups: int
269
+ A : torch.Tensor
270
+ First tensor.
271
+ B : torch.Tensor
272
+ Second tensor.
273
+ groups : int
264
274
  The number of groups.
265
275
 
266
276
  Returns
267
277
  -------
268
278
  torch.Tensor
279
+ Interleaved tensor.
269
280
 
270
281
  Raises
271
282
  ------
@@ -322,8 +333,14 @@ class UNet(nn.Module):
322
333
  Dropout probability, by default 0.0.
323
334
  pool_kernel : int, optional
324
335
  Kernel size of the pooling layers, by default 2.
325
- last_activation : Optional[Callable], optional
336
+ final_activation : Optional[Callable], optional
326
337
  Activation function to use for the last layer, by default None.
338
+ n2v2 : bool, optional
339
+ Whether to use N2V2 architecture, by default False.
340
+ independent_channels : bool
341
+ Whether to train the channels independently, by default True.
342
+ **kwargs : Any
343
+ Additional keyword arguments, unused.
327
344
  """
328
345
 
329
346
  def __init__(
@@ -362,11 +379,15 @@ class UNet(nn.Module):
362
379
  Dropout probability, by default 0.0.
363
380
  pool_kernel : int, optional
364
381
  Kernel size of the pooling layers, by default 2.
365
- last_activation : Optional[Callable], optional
382
+ final_activation : Optional[Callable], optional
366
383
  Activation function to use for the last layer, by default None.
384
+ n2v2 : bool, optional
385
+ Whether to use N2V2 architecture, by default False.
367
386
  independent_channels : bool
368
387
  Whether to train parallel independent networks for each channel, by
369
388
  default True.
389
+ **kwargs : Any
390
+ Additional keyword arguments, unused.
370
391
  """
371
392
  super().__init__()
372
393
 
careamics/prediction/stitch_prediction.py CHANGED
@@ -1,8 +1,4 @@
1
- """
2
- Prediction convenience functions.
3
-
4
- These functions are used during prediction.
5
- """
1
+ """Prediction utility functions."""
6
2
 
7
3
  from typing import List
8
4
 
@@ -21,7 +17,7 @@ def stitch_prediction(
21
17
  ----------
22
18
  tiles : List[torch.Tensor]
23
19
  Cropped tiles and their respective stitching coordinates.
24
- stitching_coords : List
20
+ stitching_data : List
25
21
  List of information and coordinates obtained from
26
22
  `dataset.tiled_patching.extract_tiles`.
27
23
 
careamics/transforms/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  __all__ = [
4
4
  "get_all_transforms",
5
5
  "N2VManipulate",
6
- "NDFlip",
6
+ "XYFlip",
7
7
  "XYRandomRotate90",
8
8
  "ImageRestorationTTA",
9
9
  "Denormalize",
@@ -14,7 +14,7 @@ __all__ = [
14
14
 
15
15
  from .compose import Compose, get_all_transforms
16
16
  from .n2v_manipulate import N2VManipulate
17
- from .nd_flip import NDFlip
18
17
  from .normalize import Denormalize, Normalize
19
18
  from .tta import ImageRestorationTTA
19
+ from .xy_flip import XYFlip
20
20
  from .xy_random_rotate90 import XYRandomRotate90
careamics/transforms/compose.py CHANGED
@@ -1,26 +1,26 @@
1
1
  """A class chaining transforms together."""
2
2
 
3
- from typing import Callable, List, Optional, Tuple
3
+ from typing import Callable, Dict, List, Optional, Tuple
4
4
 
5
5
  import numpy as np
6
6
 
7
7
  from careamics.config.data_model import TRANSFORMS_UNION
8
8
 
9
9
  from .n2v_manipulate import N2VManipulate
10
- from .nd_flip import NDFlip
11
10
  from .normalize import Normalize
12
11
  from .transform import Transform
12
+ from .xy_flip import XYFlip
13
13
  from .xy_random_rotate90 import XYRandomRotate90
14
14
 
15
15
  ALL_TRANSFORMS = {
16
16
  "Normalize": Normalize,
17
17
  "N2VManipulate": N2VManipulate,
18
- "NDFlip": NDFlip,
18
+ "XYFlip": XYFlip,
19
19
  "XYRandomRotate90": XYRandomRotate90,
20
20
  }
21
21
 
22
22
 
23
- def get_all_transforms() -> dict:
23
+ def get_all_transforms() -> Dict[str, type]:
24
24
  """Return all the transforms accepted by CAREamics.
25
25
 
26
26
  Returns
@@ -33,7 +33,19 @@ def get_all_transforms() -> dict:
33
33
 
34
34
 
35
35
  class Compose:
36
- """A class chaining transforms together."""
36
+ """A class chaining transforms together.
37
+
38
+ Parameters
39
+ ----------
40
+ transform_list : List[TRANSFORMS_UNION]
41
+ A list of dictionaries where each dictionary contains the name of a
42
+ transform and its parameters.
43
+
44
+ Attributes
45
+ ----------
46
+ _callable_transforms : Callable
47
+ A callable that applies the transforms to the input data.
48
+ """
37
49
 
38
50
  def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
39
51
  """Instantiate a Compose object.
@@ -68,7 +80,21 @@ class Compose:
68
80
 
69
81
  def _chain(
70
82
  patch: np.ndarray, target: Optional[np.ndarray]
71
- ) -> Tuple[np.ndarray, ...]:
83
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
84
+ """Chain transforms on the input data.
85
+
86
+ Parameters
87
+ ----------
88
+ patch : np.ndarray
89
+ Input data.
90
+ target : Optional[np.ndarray]
91
+ Target data, by default None.
92
+
93
+ Returns
94
+ -------
95
+ Tuple[np.ndarray, Optional[np.ndarray]]
96
+ The output of the transformations.
97
+ """
72
98
  params = (patch, target)
73
99
 
74
100
  for t in transforms:
@@ -88,7 +114,7 @@ class Compose:
88
114
  patch : np.ndarray
89
115
  The input data.
90
116
  target : Optional[np.ndarray], optional
91
- Target data, by default None
117
+ Target data, by default None.
92
118
 
93
119
  Returns
94
120
  -------
careamics/transforms/n2v_manipulate.py CHANGED
@@ -1,3 +1,5 @@
1
+ """N2V manipulation transform."""
2
+
1
3
  from typing import Any, Literal, Optional, Tuple
2
4
 
3
5
  import numpy as np
@@ -17,10 +19,35 @@ class N2VManipulate(Transform):
17
19
 
18
20
  Parameters
19
21
  ----------
20
- mask_pixel_percentage : float
21
- Approximate percentage of pixels to be masked.
22
+ roi_size : int, optional
23
+ Size of the replacement area, by default 11.
24
+ masked_pixel_percentage : float, optional
25
+ Percentage of pixels to mask, by default 0.2.
26
+ strategy : Literal[ "uniform", "median" ], optional
27
+ Replaccement strategy, uniform or median, by default uniform.
28
+ remove_center : bool, optional
29
+ Whether to remove central pixel from patch, by default True.
30
+ struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
31
+ StructN2V mask axis, by default "none".
32
+ struct_mask_span : int, optional
33
+ StructN2V mask span, by default 5.
34
+ seed : Optional[int], optional
35
+ Random seed, by default None.
36
+
37
+ Attributes
38
+ ----------
39
+ masked_pixel_percentage : float
40
+ Percentage of pixels to mask.
22
41
  roi_size : int
23
- Size of the ROI the new pixel value is sampled from, by default 11.
42
+ Size of the replacement area.
43
+ strategy : Literal[ "uniform", "median" ]
44
+ Replaccement strategy, uniform or median.
45
+ remove_center : bool
46
+ Whether to remove central pixel from patch.
47
+ struct_mask : Optional[StructMaskParameters]
48
+ StructN2V mask parameters.
49
+ rng : Generator
50
+ Random number generator.
24
51
  """
25
52
 
26
53
  def __init__(
@@ -40,24 +67,24 @@ class N2VManipulate(Transform):
40
67
  Parameters
41
68
  ----------
42
69
  roi_size : int, optional
43
- Size of the replacement area, by default 11
70
+ Size of the replacement area, by default 11.
44
71
  masked_pixel_percentage : float, optional
45
- Percentage of pixels to mask, by default 0.2
72
+ Percentage of pixels to mask, by default 0.2.
46
73
  strategy : Literal[ "uniform", "median" ], optional
47
- Replaccement strategy, uniform or median, by default uniform
74
+ Replaccement strategy, uniform or median, by default uniform.
48
75
  remove_center : bool, optional
49
- Whether to remove central pixel from patch, by default True
76
+ Whether to remove central pixel from patch, by default True.
50
77
  struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
51
- StructN2V mask axis, by default "none"
78
+ StructN2V mask axis, by default "none".
52
79
  struct_mask_span : int, optional
53
- StructN2V mask span, by default 5
80
+ StructN2V mask span, by default 5.
54
81
  seed : Optional[int], optional
55
- Random seed, by default None
82
+ Random seed, by default None.
56
83
  """
57
84
  self.masked_pixel_percentage = masked_pixel_percentage
58
85
  self.roi_size = roi_size
59
86
  self.strategy = strategy
60
- self.remove_center = remove_center
87
+ self.remove_center = remove_center # TODO is this ever used?
61
88
 
62
89
  if struct_mask_axis == SupportedStructAxis.NONE:
63
90
  self.struct_mask: Optional[StructMaskParameters] = None
@@ -77,8 +104,17 @@ class N2VManipulate(Transform):
77
104
 
78
105
  Parameters
79
106
  ----------
80
- image : np.ndarray
81
- Image or image patch, 2D or 3D, shape C(Z)YX.
107
+ patch : np.ndarray
108
+ Image patch, 2D or 3D, shape C(Z)YX.
109
+ *args : Any
110
+ Additional arguments, unused.
111
+ **kwargs : Any
112
+ Additional keyword arguments, unused.
113
+
114
+ Returns
115
+ -------
116
+ Tuple[np.ndarray, np.ndarray, np.ndarray]
117
+ Masked patch, original patch, and mask.
82
118
  """
83
119
  masked = np.zeros_like(patch)
84
120
  mask = np.zeros_like(patch)
careamics/transforms/normalize.py CHANGED
@@ -1,3 +1,5 @@
1
+ """Normalization and denormalization transforms for image patches."""
2
+
1
3
  from typing import Optional, Tuple
2
4
 
3
5
  import numpy as np
@@ -15,6 +17,13 @@ class Normalize(Transform):
15
17
  Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
16
18
  division by zero and that it returns a float32 image.
17
19
 
20
+ Parameters
21
+ ----------
22
+ mean : float
23
+ Mean value.
24
+ std : float
25
+ Standard deviation value.
26
+
18
27
  Attributes
19
28
  ----------
20
29
  mean : float
@@ -28,6 +37,15 @@ class Normalize(Transform):
28
37
  mean: float,
29
38
  std: float,
30
39
  ):
40
+ """Constructor.
41
+
42
+ Parameters
43
+ ----------
44
+ mean : float
45
+ Mean value.
46
+ std : float
47
+ Standard deviation value.
48
+ """
31
49
  self.mean = mean
32
50
  self.std = std
33
51
  self.eps = 1e-6
@@ -42,7 +60,7 @@ class Normalize(Transform):
42
60
  patch : np.ndarray
43
61
  Patch, 2D or 3D, shape C(Z)YX.
44
62
  target : Optional[np.ndarray], optional
45
- Target for the patch, by default None
63
+ Target for the patch, by default None.
46
64
 
47
65
  Returns
48
66
  -------
@@ -55,6 +73,19 @@ class Normalize(Transform):
55
73
  return norm_patch, norm_target
56
74
 
57
75
  def _apply(self, patch: np.ndarray) -> np.ndarray:
76
+ """
77
+ Apply the transform to the image.
78
+
79
+ Parameters
80
+ ----------
81
+ patch : np.ndarray
82
+ Image patch, 2D or 3D, shape C(Z)YX.
83
+
84
+ Returns
85
+ -------
86
+ np.ndarray
87
+ Normalizedimage patch.
88
+ """
58
89
  return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
59
90
 
60
91
 
@@ -69,6 +100,13 @@ class Denormalize:
69
100
  division by zero during the normalization step, which is taken into account during
70
101
  denormalization.
71
102
 
103
+ Parameters
104
+ ----------
105
+ mean : float
106
+ Mean value.
107
+ std : float
108
+ Standard deviation value.
109
+
72
110
  Attributes
73
111
  ----------
74
112
  mean : float
@@ -82,6 +120,15 @@ class Denormalize:
82
120
  mean: float,
83
121
  std: float,
84
122
  ):
123
+ """Constructor.
124
+
125
+ Parameters
126
+ ----------
127
+ mean : float
128
+ Mean.
129
+ std : float
130
+ Standard deviation.
131
+ """
85
132
  self.mean = mean
86
133
  self.std = std
87
134
  self.eps = 1e-6
@@ -96,7 +143,7 @@ class Denormalize:
96
143
  patch : np.ndarray
97
144
  Patch, 2D or 3D, shape C(Z)YX.
98
145
  target : Optional[np.ndarray], optional
99
- Target for the patch, by default None
146
+ Target for the patch, by default None.
100
147
 
101
148
  Returns
102
149
  -------
@@ -115,6 +162,11 @@ class Denormalize:
115
162
  Parameters
116
163
  ----------
117
164
  patch : np.ndarray
118
- Image or image patch, 2D or 3D, shape C(Z)YX.
165
+ Image patch, 2D or 3D, shape C(Z)YX.
166
+
167
+ Returns
168
+ -------
169
+ np.ndarray
170
+ Denormalized image patch.
119
171
  """
120
172
  return patch * (self.std + self.eps) + self.mean
careamics/transforms/pixel_manipulation.py CHANGED
@@ -5,7 +5,7 @@ Pixel manipulation is used in N2V and similar algorithm to replace the value of
5
5
  masked pixels.
6
6
  """
7
7
 
8
- from typing import Optional, Tuple, Union
8
+ from typing import Optional, Tuple
9
9
 
10
10
  import numpy as np
11
11
 
@@ -15,7 +15,7 @@ from .struct_mask_parameters import StructMaskParameters
15
15
  def _apply_struct_mask(
16
16
  patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
17
17
  ) -> np.ndarray:
18
- """Applies structN2V masks to patch.
18
+ """Apply structN2V masks to patch.
19
19
 
20
20
  Each point in `coords` corresponds to the center of a mask, masks are paremeterized
21
21
  by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by
@@ -98,7 +98,7 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
98
98
 
99
99
 
100
100
  def _get_stratified_coords(
101
- mask_pixel_perc: float, shape: Union[Tuple[int, int], Tuple[int, int, int]]
101
+ mask_pixel_perc: float, shape: Tuple[int, ...]
102
102
  ) -> np.ndarray:
103
103
  """
104
104
  Generate coordinates of the pixels to mask.
@@ -248,7 +248,7 @@ def uniform_manipulate(
248
248
  Size of the subpatch the new pixel value is sampled from, by default 11.
249
249
  remove_center : bool
250
250
  Whether to remove the center pixel from the subpatch, by default False.
251
- struct_params: Optional[StructMaskParameters]
251
+ struct_params : Optional[StructMaskParameters]
252
252
  Parameters for the structN2V mask (axis and span).
253
253
 
254
254
  Returns
@@ -322,7 +322,7 @@ def median_manipulate(
322
322
  Approximate percentage of pixels to be masked.
323
323
  subpatch_size : int
324
324
  Size of the subpatch the new pixel value is sampled from, by default 11.
325
- struct_params: Optional[StructMaskParameters]
325
+ struct_params : Optional[StructMaskParameters]
326
326
  Parameters for the structN2V mask (axis and span).
327
327
 
328
328
  Returns
careamics/transforms/struct_mask_parameters.py CHANGED
@@ -1,3 +1,5 @@
1
+ """Class representing the parameters of structN2V masks."""
2
+
1
3
  from dataclasses import dataclass
2
4
  from typing import Literal
3
5
 
@@ -6,7 +8,7 @@ from typing import Literal
6
8
  class StructMaskParameters:
7
9
  """Parameters of structN2V masks.
8
10
 
9
- Parameters
11
+ Attributes
10
12
  ----------
11
13
  axis : Literal[0, 1]
12
14
  Axis along which to apply the mask, horizontal (0) or vertical (1).
careamics/transforms/transform.py CHANGED
@@ -1,33 +1,24 @@
1
1
  """A general parent class for transforms."""
2
2
 
3
- from typing import Optional, Tuple
4
-
5
- import numpy as np
3
+ from typing import Any
6
4
 
7
5
 
8
6
  class Transform:
9
7
  """A general parent class for transforms."""
10
8
 
11
- def __call__(
12
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
13
- ) -> Tuple[np.ndarray, ...]:
14
- """Apply the transform to the input data.
9
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
10
+ """Apply the transform.
15
11
 
16
12
  Parameters
17
13
  ----------
18
- patch : np.ndarray
19
- The input data to transform.
20
- target : Optional[np.ndarray], optional
21
- The target data to transform, by default None
14
+ *args : Any
15
+ Arguments.
16
+ **kwargs : Any
17
+ Keyword arguments.
22
18
 
23
19
  Returns
24
20
  -------
25
- Tuple[np.ndarray, ...]
26
- The output of the transformations.
27
-
28
- Raises
29
- ------
30
- NotImplementedError
31
- This method should be implemented in the child class.
21
+ Any
22
+ Transformed data.
32
23
  """
33
- raise NotImplementedError
24
+ pass
careamics/transforms/xy_flip.py ADDED
@@ -0,0 +1,123 @@
1
+ """XY flip transform."""
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from careamics.transforms.transform import Transform
8
+
9
+
10
+ class XYFlip(Transform):
11
+ """Flip image along X and Y axis, one at a time.
12
+
13
+ This transform randomly flips one of the last two axes.
14
+
15
+ This transform expects C(Z)YX dimensions.
16
+
17
+ Attributes
18
+ ----------
19
+ axis_indices : List[int]
20
+ Indices of the axes that can be flipped.
21
+ rng : np.random.Generator
22
+ Random number generator.
23
+ p : float
24
+ Probability of applying the transform.
25
+ seed : Optional[int]
26
+ Random seed.
27
+
28
+ Parameters
29
+ ----------
30
+ flip_x : bool, optional
31
+ Whether to flip along the X axis, by default True.
32
+ flip_y : bool, optional
33
+ Whether to flip along the Y axis, by default True.
34
+ p : float, optional
35
+ Probability of applying the transform, by default 0.5.
36
+ seed : Optional[int], optional
37
+ Random seed, by default None.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ flip_x: bool = True,
43
+ flip_y: bool = True,
44
+ p: float = 0.5,
45
+ seed: Optional[int] = None,
46
+ ) -> None:
47
+ """Constructor.
48
+
49
+ Parameters
50
+ ----------
51
+ flip_x : bool, optional
52
+ Whether to flip along the X axis, by default True.
53
+ flip_y : bool, optional
54
+ Whether to flip along the Y axis, by default True.
55
+ p : float
56
+ Probability of applying the transform, by default 0.5.
57
+ seed : Optional[int], optional
58
+ Random seed, by default None.
59
+ """
60
+ if p < 0 or p > 1:
61
+ raise ValueError("Probability must be in [0, 1].")
62
+
63
+ if not flip_x and not flip_y:
64
+ raise ValueError("At least one axis must be flippable.")
65
+
66
+ # probability to apply the transform
67
+ self.p = p
68
+
69
+ # "flippable" axes
70
+ self.axis_indices = []
71
+
72
+ if flip_y:
73
+ self.axis_indices.append(-2)
74
+ if flip_x:
75
+ self.axis_indices.append(-1)
76
+
77
+ # numpy random generator
78
+ self.rng = np.random.default_rng(seed=seed)
79
+
80
+ def __call__(
81
+ self, patch: np.ndarray, target: Optional[np.ndarray] = None
82
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
83
+ """Apply the transform to the source patch and the target (optional).
84
+
85
+ Parameters
86
+ ----------
87
+ patch : np.ndarray
88
+ Patch, 2D or 3D, shape C(Z)YX.
89
+ target : Optional[np.ndarray], optional
90
+ Target for the patch, by default None.
91
+
92
+ Returns
93
+ -------
94
+ Tuple[np.ndarray, Optional[np.ndarray]]
95
+ Transformed patch and target.
96
+ """
97
+ if self.rng.random() > self.p:
98
+ return patch, target
99
+
100
+ # choose an axis to flip
101
+ axis = self.rng.choice(self.axis_indices)
102
+
103
+ patch_transformed = self._apply(patch, axis)
104
+ target_transformed = self._apply(target, axis) if target is not None else None
105
+
106
+ return patch_transformed, target_transformed
107
+
108
+ def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
109
+ """Apply the transform to the image.
110
+
111
+ Parameters
112
+ ----------
113
+ patch : np.ndarray
114
+ Image patch, 2D or 3D, shape C(Z)YX.
115
+ axis : int
116
+ Axis to flip.
117
+
118
+ Returns
119
+ -------
120
+ np.ndarray
121
+ Flipped image patch.
122
+ """
123
+ return np.ascontiguousarray(np.flip(patch, axis=axis))