careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118) hide show
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,34 @@
1
- from typing import Optional, Tuple
1
+ """Normalization and denormalization transforms for image patches."""
2
+
3
+ from typing import Optional
2
4
 
3
5
  import numpy as np
6
+ from numpy.typing import NDArray
4
7
 
5
8
  from careamics.transforms.transform import Transform
6
9
 
7
10
 
11
+ def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
12
+ """Reshape stats to match the number of dimensions of the input image.
13
+
14
+ This allows to broadcast the stats (mean or std) to the image dimensions, and
15
+ thus directly perform a vectorial calculation.
16
+
17
+ Parameters
18
+ ----------
19
+ stats : list of float
20
+ List of stats, mean or standard deviation.
21
+ ndim : int
22
+ Number of dimensions of the image, including the C channel.
23
+
24
+ Returns
25
+ -------
26
+ NDArray
27
+ Reshaped stats.
28
+ """
29
+ return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
30
+
31
+
8
32
  class Normalize(Transform):
9
33
  """
10
34
  Normalize an image or image patch.
@@ -15,106 +39,205 @@ class Normalize(Transform):
15
39
  Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
16
40
  division by zero and that it returns a float32 image.
17
41
 
42
+ Parameters
43
+ ----------
44
+ image_means : list of float
45
+ Mean value per channel.
46
+ image_stds : list of float
47
+ Standard deviation value per channel.
48
+ target_means : list of float, optional
49
+ Target mean value per channel, by default None.
50
+ target_stds : list of float, optional
51
+ Target standard deviation value per channel, by default None.
52
+
18
53
  Attributes
19
54
  ----------
20
- mean : float
21
- Mean value.
22
- std : float
23
- Standard deviation value.
55
+ image_means : list of float
56
+ Mean value per channel.
57
+ image_stds : list of float
58
+ Standard deviation value per channel.
59
+ target_means :list of float, optional
60
+ Target mean value per channel, by default None.
61
+ target_stds : list of float, optional
62
+ Target standard deviation value per channel, by default None.
24
63
  """
25
64
 
26
65
  def __init__(
27
66
  self,
28
- mean: float,
29
- std: float,
67
+ image_means: list[float],
68
+ image_stds: list[float],
69
+ target_means: Optional[list[float]] = None,
70
+ target_stds: Optional[list[float]] = None,
30
71
  ):
31
- self.mean = mean
32
- self.std = std
72
+ """Constructor.
73
+
74
+ Parameters
75
+ ----------
76
+ image_means : list of float
77
+ Mean value per channel.
78
+ image_stds : list of float
79
+ Standard deviation value per channel.
80
+ target_means : list of float, optional
81
+ Target mean value per channel, by default None.
82
+ target_stds : list of float, optional
83
+ Target standard deviation value per channel, by default None.
84
+ """
85
+ self.image_means = image_means
86
+ self.image_stds = image_stds
87
+ self.target_means = target_means
88
+ self.target_stds = target_stds
89
+
33
90
  self.eps = 1e-6
34
91
 
35
92
  def __call__(
36
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
37
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
93
+ self, patch: np.ndarray, target: Optional[NDArray] = None
94
+ ) -> tuple[NDArray, Optional[NDArray]]:
38
95
  """Apply the transform to the source patch and the target (optional).
39
96
 
40
97
  Parameters
41
98
  ----------
42
- patch : np.ndarray
99
+ patch : NDArray
43
100
  Patch, 2D or 3D, shape C(Z)YX.
44
- target : Optional[np.ndarray], optional
45
- Target for the patch, by default None
101
+ target : NDArray, optional
102
+ Target for the patch, by default None.
46
103
 
47
104
  Returns
48
105
  -------
49
- Tuple[np.ndarray, Optional[np.ndarray]]
50
- Transformed patch and target.
106
+ tuple of NDArray
107
+ Transformed patch and target, the target can be returned as `None`.
51
108
  """
52
- norm_patch = self._apply(patch)
53
- norm_target = self._apply(target) if target is not None else None
109
+ if len(self.image_means) != patch.shape[0]:
110
+ raise ValueError(
111
+ f"Number of means (got a list of size {len(self.image_means)}) and "
112
+ f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
113
+ )
114
+
115
+ # reshape mean and std and apply the normalization to the patch
116
+ means = _reshape_stats(self.image_means, patch.ndim)
117
+ stds = _reshape_stats(self.image_stds, patch.ndim)
118
+ norm_patch = self._apply(patch, means, stds)
119
+
120
+ # same for the target patch
121
+ if (
122
+ target is not None
123
+ and self.target_means is not None
124
+ and self.target_stds is not None
125
+ ):
126
+ target_means = _reshape_stats(self.target_means, target.ndim)
127
+ target_stds = _reshape_stats(self.target_stds, target.ndim)
128
+ norm_target = self._apply(target, target_means, target_stds)
129
+ else:
130
+ norm_target = None
54
131
 
55
132
  return norm_patch, norm_target
56
133
 
57
- def _apply(self, patch: np.ndarray) -> np.ndarray:
58
- return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
134
+ def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
135
+ """
136
+ Apply the transform to the image.
137
+
138
+ Parameters
139
+ ----------
140
+ patch : NDArray
141
+ Image patch, 2D or 3D, shape C(Z)YX.
142
+ mean : NDArray
143
+ Mean values.
144
+ std : NDArray
145
+ Standard deviations.
146
+
147
+ Returns
148
+ -------
149
+ NDArray
150
+ Normalized image patch.
151
+ """
152
+ return ((patch - mean) / (std + self.eps)).astype(np.float32)
59
153
 
60
154
 
61
155
  class Denormalize:
62
156
  """
63
- Denormalize an image or image patch.
157
+ Denormalize an image.
64
158
 
65
159
  Denormalization is performed expecting a zero mean and unit variance input. This
66
160
  transform expects C(Z)YX dimensions.
67
161
 
68
- Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
162
+ Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
69
163
  division by zero during the normalization step, which is taken into account during
70
164
  denormalization.
71
165
 
72
- Attributes
166
+ Parameters
73
167
  ----------
74
- mean : float
75
- Mean value.
76
- std : float
77
- Standard deviation value.
168
+ image_means : list or tuple of float
169
+ Mean value per channel.
170
+ image_stds : list or tuple of float
171
+ Standard deviation value per channel.
172
+
78
173
  """
79
174
 
80
175
  def __init__(
81
176
  self,
82
- mean: float,
83
- std: float,
177
+ image_means: list[float],
178
+ image_stds: list[float],
84
179
  ):
85
- self.mean = mean
86
- self.std = std
180
+ """Constructor.
181
+
182
+ Parameters
183
+ ----------
184
+ image_means : list of float
185
+ Mean value per channel.
186
+ image_stds : list of float
187
+ Standard deviation value per channel.
188
+ """
189
+ self.image_means = image_means
190
+ self.image_stds = image_stds
191
+
87
192
  self.eps = 1e-6
88
193
 
89
- def __call__(
90
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
91
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
92
- """Apply the transform to the source patch and the target (optional).
194
+ def __call__(self, patch: NDArray) -> NDArray:
195
+ """Reverse the normalization operation for a batch of patches.
93
196
 
94
197
  Parameters
95
198
  ----------
96
- patch : np.ndarray
97
- Patch, 2D or 3D, shape C(Z)YX.
98
- target : Optional[np.ndarray], optional
99
- Target for the patch, by default None
199
+ patch : NDArray
200
+ Patch, 2D or 3D, shape BC(Z)YX.
100
201
 
101
202
  Returns
102
203
  -------
103
- Tuple[np.ndarray, Optional[np.ndarray]]
104
- Transformed patch and target.
204
+ NDArray
205
+ Transformed array.
105
206
  """
106
- norm_patch = self._apply(patch)
107
- norm_target = self._apply(target) if target is not None else None
207
+ if len(self.image_means) != patch.shape[1]:
208
+ raise ValueError(
209
+ f"Number of means (got a list of size {len(self.image_means)}) and "
210
+ f"number of channels (got shape {patch.shape} for BC(Z)YX) do not "
211
+ f"match."
212
+ )
108
213
 
109
- return norm_patch, norm_target
214
+ means = _reshape_stats(self.image_means, patch.ndim)
215
+ stds = _reshape_stats(self.image_stds, patch.ndim)
110
216
 
111
- def _apply(self, patch: np.ndarray) -> np.ndarray:
217
+ denorm_array = self._apply(
218
+ patch,
219
+ np.swapaxes(means, 0, 1), # swap axes as C channel is axis 1
220
+ np.swapaxes(stds, 0, 1),
221
+ )
222
+
223
+ return denorm_array.astype(np.float32)
224
+
225
+ def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
112
226
  """
113
227
  Apply the transform to the image.
114
228
 
115
229
  Parameters
116
230
  ----------
117
- patch : np.ndarray
118
- Image or image patch, 2D or 3D, shape C(Z)YX.
231
+ array : NDArray
232
+ Image patch, 2D or 3D, shape C(Z)YX.
233
+ mean : NDArray
234
+ Mean values.
235
+ std : NDArray
236
+ Standard deviations.
237
+
238
+ Returns
239
+ -------
240
+ NDArray
241
+ Denormalized image array.
119
242
  """
120
- return patch * (self.std + self.eps) + self.mean
243
+ return array * (std + self.eps) + mean
@@ -5,7 +5,7 @@ Pixel manipulation is used in N2V and similar algorithm to replace the value of
5
5
  masked pixels.
6
6
  """
7
7
 
8
- from typing import Optional, Tuple, Union
8
+ from typing import Optional, Tuple
9
9
 
10
10
  import numpy as np
11
11
 
@@ -13,9 +13,12 @@ from .struct_mask_parameters import StructMaskParameters
13
13
 
14
14
 
15
15
  def _apply_struct_mask(
16
- patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
16
+ patch: np.ndarray,
17
+ coords: np.ndarray,
18
+ struct_params: StructMaskParameters,
19
+ rng: Optional[np.random.Generator] = None,
17
20
  ) -> np.ndarray:
18
- """Applies structN2V masks to patch.
21
+ """Apply structN2V masks to patch.
19
22
 
20
23
  Each point in `coords` corresponds to the center of a mask, masks are paremeterized
21
24
  by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by
@@ -31,12 +34,17 @@ def _apply_struct_mask(
31
34
  Coordinates of the ROI(subpatch) centers.
32
35
  struct_params : StructMaskParameters
33
36
  Parameters for the structN2V mask (axis and span).
37
+ rng : np.random.Generator or None
38
+ Random number generator.
34
39
 
35
40
  Returns
36
41
  -------
37
42
  np.ndarray
38
43
  Patch with the structN2V mask applied.
39
44
  """
45
+ if rng is None:
46
+ rng = np.random.default_rng()
47
+
40
48
  # relative axis
41
49
  moving_axis = -1 - struct_params.axis
42
50
 
@@ -67,7 +75,7 @@ def _apply_struct_mask(
67
75
  mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0)
68
76
 
69
77
  # replace neighbouring pixels with random values from flat dist
70
- patch[tuple(mix.T)] = np.random.uniform(patch.min(), patch.max(), size=mix.shape[0])
78
+ patch[tuple(mix.T)] = rng.uniform(patch.min(), patch.max(), size=mix.shape[0])
71
79
 
72
80
  return patch
73
81
 
@@ -98,7 +106,9 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
98
106
 
99
107
 
100
108
  def _get_stratified_coords(
101
- mask_pixel_perc: float, shape: Union[Tuple[int, int], Tuple[int, int, int]]
109
+ mask_pixel_perc: float,
110
+ shape: Tuple[int, ...],
111
+ rng: Optional[np.random.Generator] = None,
102
112
  ) -> np.ndarray:
103
113
  """
104
114
  Generate coordinates of the pixels to mask.
@@ -113,6 +123,8 @@ def _get_stratified_coords(
113
123
  calculating the distance between masked pixels across each axis.
114
124
  shape : Tuple[int, ...]
115
125
  Shape of the input patch.
126
+ rng : np.random.Generator or None
127
+ Random number generator.
116
128
 
117
129
  Returns
118
130
  -------
@@ -124,7 +136,8 @@ def _get_stratified_coords(
124
136
  "Calculating coordinates is only possible for 2D and 3D patches"
125
137
  )
126
138
 
127
- rng = np.random.default_rng()
139
+ if rng is None:
140
+ rng = np.random.default_rng()
128
141
 
129
142
  mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
130
143
  np.int32
@@ -228,6 +241,7 @@ def uniform_manipulate(
228
241
  subpatch_size: int = 11,
229
242
  remove_center: bool = True,
230
243
  struct_params: Optional[StructMaskParameters] = None,
244
+ rng: Optional[np.random.Generator] = None,
231
245
  ) -> Tuple[np.ndarray, np.ndarray]:
232
246
  """
233
247
  Manipulate pixels by replacing them with a neighbor values.
@@ -248,19 +262,23 @@ def uniform_manipulate(
248
262
  Size of the subpatch the new pixel value is sampled from, by default 11.
249
263
  remove_center : bool
250
264
  Whether to remove the center pixel from the subpatch, by default False.
251
- struct_params: Optional[StructMaskParameters]
265
+ struct_params : StructMaskParameters or None
252
266
  Parameters for the structN2V mask (axis and span).
267
+ rng : np.random.Generator or None
268
+ Random number generator.
253
269
 
254
270
  Returns
255
271
  -------
256
272
  Tuple[np.ndarray]
257
273
  Tuple containing the manipulated patch and the corresponding mask.
258
274
  """
275
+ if rng is None:
276
+ rng = np.random.default_rng()
277
+
259
278
  # Get the coordinates of the pixels to be replaced
260
279
  transformed_patch = patch.copy()
261
280
 
262
- subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
263
- rng = np.random.default_rng()
281
+ subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
264
282
 
265
283
  # Generate coordinate grid for subpatch
266
284
  roi_span_full = np.arange(
@@ -303,6 +321,7 @@ def median_manipulate(
303
321
  mask_pixel_percentage: float,
304
322
  subpatch_size: int = 11,
305
323
  struct_params: Optional[StructMaskParameters] = None,
324
+ rng: Optional[np.random.Generator] = None,
306
325
  ) -> Tuple[np.ndarray, np.ndarray]:
307
326
  """
308
327
  Manipulate pixels by replacing them with the median of their surrounding subpatch.
@@ -322,18 +341,23 @@ def median_manipulate(
322
341
  Approximate percentage of pixels to be masked.
323
342
  subpatch_size : int
324
343
  Size of the subpatch the new pixel value is sampled from, by default 11.
325
- struct_params: Optional[StructMaskParameters]
344
+ struct_params : StructMaskParameters or None, optional
326
345
  Parameters for the structN2V mask (axis and span).
346
+ rng : np.random.Generator or None, optional
347
+ Random number generato, by default None.
327
348
 
328
349
  Returns
329
350
  -------
330
351
  Tuple[np.ndarray]
331
352
  Tuple containing the manipulated patch, the original patch and the mask.
332
353
  """
354
+ if rng is None:
355
+ rng = np.random.default_rng()
356
+
333
357
  transformed_patch = patch.copy()
334
358
 
335
359
  # Get the coordinates of the pixels to be replaced
336
- subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
360
+ subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
337
361
 
338
362
  # Generate coordinate grid for subpatch
339
363
  roi_span = np.array(
@@ -1,3 +1,5 @@
1
+ """Class representing the parameters of structN2V masks."""
2
+
1
3
  from dataclasses import dataclass
2
4
  from typing import Literal
3
5
 
@@ -6,7 +8,7 @@ from typing import Literal
6
8
  class StructMaskParameters:
7
9
  """Parameters of structN2V masks.
8
10
 
9
- Parameters
11
+ Attributes
10
12
  ----------
11
13
  axis : Literal[0, 1]
12
14
  Axis along which to apply the mask, horizontal (0) or vertical (1).
@@ -1,33 +1,24 @@
1
1
  """A general parent class for transforms."""
2
2
 
3
- from typing import Optional, Tuple
4
-
5
- import numpy as np
3
+ from typing import Any
6
4
 
7
5
 
8
6
  class Transform:
9
7
  """A general parent class for transforms."""
10
8
 
11
- def __call__(
12
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
13
- ) -> Tuple[np.ndarray, ...]:
14
- """Apply the transform to the input data.
9
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
10
+ """Apply the transform.
15
11
 
16
12
  Parameters
17
13
  ----------
18
- patch : np.ndarray
19
- The input data to transform.
20
- target : Optional[np.ndarray], optional
21
- The target data to transform, by default None
14
+ *args : Any
15
+ Arguments.
16
+ **kwargs : Any
17
+ Keyword arguments.
22
18
 
23
19
  Returns
24
20
  -------
25
- Tuple[np.ndarray, ...]
26
- The output of the transformations.
27
-
28
- Raises
29
- ------
30
- NotImplementedError
31
- This method should be implemented in the child class.
21
+ Any
22
+ Transformed data.
32
23
  """
33
- raise NotImplementedError
24
+ pass
@@ -1,11 +1,8 @@
1
1
  """Test-time augmentations."""
2
2
 
3
- from typing import List
4
-
5
3
  from torch import Tensor, flip, mean, rot90, stack
6
4
 
7
5
 
8
- # TODO add tests
9
6
  class ImageRestorationTTA:
10
7
  """
11
8
  Test-time augmentation for image restoration tasks.
@@ -13,62 +10,79 @@ class ImageRestorationTTA:
13
10
  The augmentation is performed using all 90 deg rotations and their flipped version,
14
11
  as well as the original image flipped.
15
12
 
16
- Tensors should be of shape SC(Z)YX
13
+ Tensors should be of shape SC(Z)YX.
17
14
 
18
15
  This transformation is used in the LightningModule in order to perform test-time
19
- agumentation.
16
+ augmentation.
20
17
  """
21
18
 
22
- def __init__(self) -> None:
23
- """Constructor."""
24
- pass
25
-
26
- def forward(self, x: Tensor) -> List[Tensor]:
19
+ def forward(self, input_tensor: Tensor) -> list[Tensor]:
27
20
  """
28
21
  Apply test-time augmentation to the input tensor.
29
22
 
30
23
  Parameters
31
24
  ----------
32
- x : Tensor
25
+ input_tensor : Tensor
33
26
  Input tensor, shape SC(Z)YX.
34
27
 
35
28
  Returns
36
29
  -------
37
- List[Tensor]
30
+ list of torch.Tensor
38
31
  List of augmented tensors.
39
32
  """
33
+ # axes: only applies to YX axes
34
+ axes = (-2, -1)
35
+
40
36
  augmented = [
41
- x,
42
- rot90(x, 1, dims=(-2, -1)),
43
- rot90(x, 2, dims=(-2, -1)),
44
- rot90(x, 3, dims=(-2, -1)),
37
+ # original
38
+ input_tensor,
39
+ # rotations
40
+ rot90(input_tensor, 1, dims=axes),
41
+ rot90(input_tensor, 2, dims=axes),
42
+ rot90(input_tensor, 3, dims=axes),
43
+ # original flipped
44
+ flip(input_tensor, dims=(axes[0],)),
45
+ flip(input_tensor, dims=(axes[1],)),
45
46
  ]
46
- augmented_flip = augmented.copy()
47
- for x_ in augmented:
48
- augmented_flip.append(flip(x_, dims=(-3, -1)))
49
- return augmented_flip
50
47
 
51
- def backward(self, x: List[Tensor]) -> Tensor:
48
+ # rotated once, flipped
49
+ augmented.extend(
50
+ [
51
+ flip(augmented[1], dims=(axes[0],)),
52
+ flip(augmented[1], dims=(axes[1],)),
53
+ ]
54
+ )
55
+
56
+ return augmented
57
+
58
+ def backward(self, x: list[Tensor]) -> Tensor:
52
59
  """Undo the test-time augmentation.
53
60
 
54
61
  Parameters
55
62
  ----------
56
63
  x : Any
57
- List of augmented tensors.
64
+ List of augmented tensors of shape SC(Z)YX.
58
65
 
59
66
  Returns
60
67
  -------
61
68
  Any
62
69
  Original tensor.
63
70
  """
71
+ axes = (-2, -1)
72
+
64
73
  reverse = [
74
+ # original
65
75
  x[0],
66
- rot90(x[1], -1, dims=(-2, -1)),
67
- rot90(x[2], -2, dims=(-2, -1)),
68
- rot90(x[3], -3, dims=(-2, -1)),
69
- flip(x[4], dims=(-3, -1)),
70
- rot90(flip(x[5], dims=(-3, -1)), -1, dims=(-2, -1)),
71
- rot90(flip(x[6], dims=(-3, -1)), -2, dims=(-2, -1)),
72
- rot90(flip(x[7], dims=(-3, -1)), -3, dims=(-2, -1)),
76
+ # rotated
77
+ rot90(x[1], -1, dims=axes),
78
+ rot90(x[2], -2, dims=axes),
79
+ rot90(x[3], -3, dims=axes),
80
+ # original flipped
81
+ flip(x[4], dims=(axes[0],)),
82
+ flip(x[5], dims=(axes[1],)),
83
+ # rotated once, flipped
84
+ rot90(flip(x[6], dims=(axes[0],)), -1, dims=axes),
85
+ rot90(flip(x[7], dims=(axes[1],)), -1, dims=axes),
73
86
  ]
87
+
74
88
  return mean(stack(reverse), dim=0)