careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,407 @@
1
+ """
2
+ Pixel manipulation methods.
3
+
4
+ Pixel manipulation is used in N2V and similar algorithm to replace the value of
5
+ masked pixels.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ from .struct_mask_parameters import StructMaskParameters
13
+
14
+
15
def _apply_struct_mask(
    patch: np.ndarray,
    coords: np.ndarray,
    struct_params: StructMaskParameters,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Apply structN2V masks to patch.

    Each point in `coords` is the center of a line-shaped mask of
    `struct_params.span` pixels running along a single array axis:
    `struct_params.axis == 0` selects the last array axis and
    `struct_params.axis == 1` the second-to-last one. The line is positioned so
    that the point sits at index `span // 2` of it (asymmetric for even spans).
    Every pixel on such a line, except the center itself, is overwritten in place
    with a value drawn uniformly from `[patch.min(), patch.max())`. Line pixels
    falling outside the patch along the masked axis are discarded.

    Note that the structN2V mask is applied in 2D at the coordinates given by `coords`.

    Parameters
    ----------
    patch : np.ndarray
        Patch to be manipulated, 2D or 3D. It is modified in place.
    coords : np.ndarray
        Coordinates of the ROI (subpatch) centers, shape (n_points, patch.ndim).
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None
        Random number generator; a fresh unseeded generator is created if None.

    Returns
    -------
    np.ndarray
        The same `patch` object, with the structN2V mask applied.
    """
    if rng is None:
        rng = np.random.default_rng()

    # relative axis: 0 -> -1 (last axis), 1 -> -2 (second-to-last axis)
    moving_axis = -1 - struct_params.axis

    # Create a mask array: a line of `span` ones, padded with singleton axes
    # so it has the same number of dimensions as the patch
    mask = np.expand_dims(
        np.ones(struct_params.span), axis=list(range(len(patch.shape) - 1))
    )  # (1, 1, span) or (1, span)

    # Move the moving axis to the correct position
    # i.e. the axis along which the coordinates should change
    mask = np.moveaxis(mask, -1, moving_axis)
    center = np.array(mask.shape) // 2

    # Mark the center so that the already-manipulated pixel is left untouched
    mask[tuple(center.T)] = 0

    # displacements from center, shape (ndim, span - 1)
    dx = np.indices(mask.shape)[:, mask == 1] - center[:, None]

    # combine all coords (ndim, npts,) with all displacements (ncoords,ndim,)
    # -> one row per (displacement, point) pair, displacement-major order
    mix = dx.T[..., None] + coords.T[None]
    mix = mix.transpose([1, 0, 2]).reshape([mask.ndim, -1]).T

    # delete entries that are out of bounds (only the moving axis can leave
    # the patch, the other coordinates are the centers' own)
    mix = np.delete(mix, mix[:, moving_axis] < 0, axis=0)

    max_bound = patch.shape[moving_axis] - 1
    mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0)

    # replace neighbouring pixels with random values from flat dist
    patch[tuple(mix.T)] = rng.uniform(patch.min(), patch.max(), size=mix.shape[0])

    return patch
81
+
82
+
83
+ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
84
+ """
85
+ Randomly sample a jitter to be applied to the masking grid.
86
+
87
+ This is done to account for cases where the step size is not an integer.
88
+
89
+ Parameters
90
+ ----------
91
+ step : float
92
+ Step size of the grid, output of np.linspace.
93
+ rng : np.random.Generator
94
+ Random number generator.
95
+
96
+ Returns
97
+ -------
98
+ np.ndarray
99
+ Array of random jitter to be added to the grid.
100
+ """
101
+ # Define the random jitter to be added to the grid
102
+ odd_jitter = np.where(np.floor(step) == step, 0, rng.integers(0, 2))
103
+
104
+ # Round the step size to the nearest integer depending on the jitter
105
+ return np.floor(step) if odd_jitter == 0 else np.ceil(step)
106
+
107
+
108
def _get_stratified_coords(
    mask_pixel_perc: float,
    shape: Tuple[int, ...],
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Draw the coordinates of the pixels to mask, one per cell of a regular grid.

    A regular grid is laid over the patch with a spacing derived from
    `mask_pixel_perc`. Every grid point is then shifted by an independent random
    offset (the same offset range on all axes) and clipped to the patch, so that
    masked pixels end up approximately evenly spread.

    Parameters
    ----------
    mask_pixel_perc : float
        Actual (quasi) percentage of masked pixels across the whole image. Used in
        calculating the distance between masked pixels across each axis.
    shape : Tuple[int, ...]
        Shape of the input patch (2D or 3D).
    rng : np.random.Generator or None
        Random number generator; a fresh unseeded generator is created if None.

    Returns
    -------
    np.ndarray
        Integer array of shape (n_points, len(shape)) with the coordinates of the
        masked pixels.

    Raises
    ------
    ValueError
        If `shape` is neither 2D nor 3D.
    """
    n_dims = len(shape)
    if not 2 <= n_dims <= 3:
        raise ValueError(
            "Calculating coordinates is only possible for 2D and 3D patches"
        )

    if rng is None:
        rng = np.random.default_rng()

    # Target distance between two masked pixels along each axis
    spacing = np.round((100 / mask_pixel_perc) ** (1 / n_dims)).astype(np.int32)

    # Regular grid positions (and their, possibly fractional, step) per axis
    axis_grids = []
    axis_steps = []
    for size in shape:
        n_points = int(np.ceil(size / spacing))
        grid, step = np.linspace(
            0, size, n_points, dtype=np.int32, endpoint=False, retstep=True
        )
        axis_grids.append(grid)
        axis_steps.append(step)

    # Cartesian product of the per-axis positions, one row per grid point
    grid_points = np.array(np.meshgrid(*axis_grids)).reshape(n_dims, -1).T

    # Random shift in [0, rounded largest step - 1] for every coordinate; the
    # rounding draw must happen before the shift draw (RNG order)
    jitter_extent = _odd_jitter_func(float(max(axis_steps)), rng)  # type: ignore
    upper = jitter_extent * np.ones_like(grid_points).astype(np.int32) - 1
    offsets = rng.integers(upper, size=grid_points.shape, endpoint=True)

    grid_points += offsets
    return np.clip(grid_points, 0, np.array(shape) - 1)
173
+
174
+
175
+ def _create_subpatch_center_mask(
176
+ subpatch: np.ndarray, center_coords: np.ndarray
177
+ ) -> np.ndarray:
178
+ """Create a mask with the center of the subpatch masked.
179
+
180
+ Parameters
181
+ ----------
182
+ subpatch : np.ndarray
183
+ Subpatch to be manipulated.
184
+ center_coords : np.ndarray
185
+ Coordinates of the original center before possible crop.
186
+
187
+ Returns
188
+ -------
189
+ np.ndarray
190
+ Mask with the center of the subpatch masked.
191
+ """
192
+ mask = np.ones(subpatch.shape)
193
+ mask[tuple(center_coords)] = 0
194
+ return np.ma.make_mask(mask) # type: ignore
195
+
196
+
197
+ def _create_subpatch_struct_mask(
198
+ subpatch: np.ndarray, center_coords: np.ndarray, struct_params: StructMaskParameters
199
+ ) -> np.ndarray:
200
+ """Create a structN2V mask for the subpatch.
201
+
202
+ Parameters
203
+ ----------
204
+ subpatch : np.ndarray
205
+ Subpatch to be manipulated.
206
+ center_coords : np.ndarray
207
+ Coordinates of the original center before possible crop.
208
+ struct_params : StructMaskParameters
209
+ Parameters for the structN2V mask (axis and span).
210
+
211
+ Returns
212
+ -------
213
+ np.ndarray
214
+ StructN2V mask for the subpatch.
215
+ """
216
+ # Create a mask with the center of the subpatch masked
217
+ mask_placeholder = np.ones(subpatch.shape)
218
+
219
+ # reshape to move the struct axis to the first position
220
+ mask_reshaped = np.moveaxis(mask_placeholder, struct_params.axis, 0)
221
+
222
+ # create the mask index for the struct axis
223
+ mask_index = slice(
224
+ max(0, center_coords.take(struct_params.axis) - (struct_params.span - 1) // 2),
225
+ min(
226
+ 1 + center_coords.take(struct_params.axis) + (struct_params.span - 1) // 2,
227
+ subpatch.shape[struct_params.axis],
228
+ ),
229
+ )
230
+ mask_reshaped[struct_params.axis][mask_index] = 0
231
+
232
+ # reshape back to the original shape
233
+ mask = np.moveaxis(mask_reshaped, 0, struct_params.axis)
234
+
235
+ return np.ma.make_mask(mask) # type: ignore
236
+
237
+
238
def uniform_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: Optional[StructMaskParameters] = None,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with a neighbor value.

    Pixels to manipulate are chosen around a stratified grid (see
    `_get_stratified_coords`), giving each pixel an approximately uniform
    probability of being selected across the whole patch. Each selected pixel is
    replaced by the value of a pixel drawn uniformly from the `subpatch_size`-wide
    neighborhood centered on it (coordinates clipped to the patch). If
    `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to exclude the center pixel from the candidate replacement
        offsets, by default True.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None
        Random number generator. It is used for every random draw, including the
        structN2V mask, so a seeded generator gives reproducible output.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The manipulated patch and a `uint8` mask of the same shape, equal to 1 at
        every manipulated pixel and 0 elsewhere.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
    center_index = tuple(subpatch_centers.T)

    # Offsets, relative to each center, from which replacement values are drawn
    roi_span_full = np.arange(
        -np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)
    ).astype(np.int32)

    # Remove the center pixel from the candidate offsets if needed
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full

    # Randomly select one offset per coordinate of every center
    random_increment = rng.choice(roi_span, size=subpatch_centers.shape)

    # Clip the replacement coordinates to the patch
    replacement_coords = np.clip(
        subpatch_centers + random_increment,
        0,
        np.array(patch.shape) - 1,
    )

    # Read the replacement values from the original patch and write them in place
    transformed_patch[center_index] = patch[tuple(replacement_coords.T)]

    # Mark the manipulated positions themselves: comparing values would drop
    # pixels whose replacement happens to equal the original value (flat regions,
    # `remove_center=False`, offsets clipped back onto the center).
    mask = np.zeros(patch.shape, dtype=np.uint8)
    mask[center_index] = 1

    if struct_params is not None:
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng=rng
        )

    return (
        transformed_patch,
        mask,
    )
317
+
318
+
319
def median_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: Optional[StructMaskParameters] = None,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version: manipulated pixels are selected randomly around a stratified grid,
    with an approximately uniform probability of being selected across the whole
    patch. Each selected pixel is replaced by the median of the `subpatch_size`-wide
    subpatch centered on it (clipped at the patch borders), computed on the original
    patch over the pixels selected by `_create_subpatch_center_mask` (or by
    `_create_subpatch_struct_mask` when `struct_params` is given).

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed from, by default 11.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span), by default None.
    rng : np.random.Generator or None, optional
        Random number generator, by default None. It is used for every random draw,
        including the structN2V mask, so a seeded generator gives reproducible output.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The manipulated patch and a `uint8` mask of the same shape, equal to 1 at
        every manipulated (median-replaced) pixel and 0 elsewhere.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)

    # [start, stop) offsets of a subpatch relative to its center
    roi_span = np.array(
        [-np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)]
    ).astype(np.int32)

    # Dimensions: n dims, n centers, (start, stop)
    subpatch_crops_span_full = subpatch_centers[np.newaxis, ...].T + roi_span
    subpatch_crops_span_clipped = np.clip(
        subpatch_crops_span_full,
        a_min=np.zeros_like(patch.shape)[:, np.newaxis, np.newaxis],
        a_max=np.array(patch.shape)[:, np.newaxis, np.newaxis],
    )

    # Iterate over (center, per-axis crop bounds) pairs
    for center, crop in zip(
        subpatch_centers, subpatch_crops_span_clipped.transpose(1, 0, 2)
    ):
        idxs = [
            slice(start, stop) if stop - start > 0 else slice(0, 1)
            for start, stop in crop
        ]
        # Read from the original patch so already-replaced pixels are not reused
        subpatch = patch[tuple(idxs)]
        subpatch_center_adjusted = center - crop[:, 0]

        if struct_params is None:
            subpatch_mask = _create_subpatch_center_mask(
                subpatch, subpatch_center_adjusted
            )
        else:
            subpatch_mask = _create_subpatch_struct_mask(
                subpatch, subpatch_center_adjusted, struct_params
            )
        transformed_patch[tuple(center)] = np.median(subpatch[subpatch_mask])

    # Mark the manipulated positions themselves: the median frequently equals the
    # original value, and a value comparison would drop those pixels from the mask.
    mask = np.zeros(patch.shape, dtype=np.uint8)
    mask[tuple(subpatch_centers.T)] = 1

    if struct_params is not None:
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng=rng
        )

    return (
        transformed_patch,
        mask,
    )
@@ -0,0 +1,20 @@
1
+ """Class representing the parameters of structN2V masks."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+
6
+
7
@dataclass
class StructMaskParameters:
    """Parameters of structN2V masks.

    Plain data container; no validation is performed on the fields.

    Attributes
    ----------
    axis : Literal[0, 1]
        Axis along which to apply the mask, horizontal (0) or vertical (1).
        NOTE(review): the pixel-manipulation helpers in this package map this to
        an array axis themselves; confirm the mapping there before relying on it.
    span : int
        Span of the mask, i.e. the length of the mask line in pixels (center
        pixel included).
    """

    # Direction of the mask line: 0 = horizontal, 1 = vertical.
    axis: Literal[0, 1]
    # Length of the mask line, in pixels.
    span: int
@@ -0,0 +1,24 @@
1
+ """A general parent class for transforms."""
2
+
3
+ from typing import Any
4
+
5
+
6
class Transform:
    """Common base class of the CAREamics transforms.

    Calling an instance of the base class is a no-op: it accepts any positional
    and keyword arguments and returns ``None``. Concrete transforms override
    ``__call__``.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Apply the transform.

        Parameters
        ----------
        *args : Any
            Positional arguments, ignored by the base class.
        **kwargs : Any
            Keyword arguments, ignored by the base class.

        Returns
        -------
        Any
            ``None`` for the base class; subclasses return the transformed data.
        """
        return None
@@ -0,0 +1,88 @@
1
+ """Test-time augmentations."""
2
+
3
+ from torch import Tensor, flip, mean, rot90, stack
4
+
5
+
6
class ImageRestorationTTA:
    """
    Test-time augmentation (TTA) for image restoration tasks.

    `forward` expands an input into its eight dihedral variants over the last two
    (Y, X) axes: the identity, the three 90 degree rotations, the Y and X flips of
    the original, and the Y and X flips of the once-rotated input. `backward` maps
    every variant back to the original orientation and averages them.

    Tensors should be of shape SC(Z)YX.

    This transformation is used in the LightningModule in order to perform test-time
    augmentation.
    """

    def forward(self, input_tensor: Tensor) -> list[Tensor]:
        """
        Build the eight augmented versions of the input tensor.

        Parameters
        ----------
        input_tensor : Tensor
            Input tensor, shape SC(Z)YX.

        Returns
        -------
        list of torch.Tensor
            Augmented tensors, in the order: original, rotations by 1, 2 and 3
            quarter turns, Y flip, X flip, Y flip of the 1-turn rotation, X flip of
            the 1-turn rotation.
        """
        yx_dims = (-2, -1)

        rotated = [rot90(input_tensor, turns, dims=yx_dims) for turns in (1, 2, 3)]
        flipped = [flip(input_tensor, dims=(dim,)) for dim in yx_dims]
        rotated_then_flipped = [flip(rotated[0], dims=(dim,)) for dim in yx_dims]

        return [input_tensor, *rotated, *flipped, *rotated_then_flipped]

    def backward(self, x: list[Tensor]) -> Tensor:
        """Undo the test-time augmentation and average the results.

        Parameters
        ----------
        x : list of torch.Tensor
            The eight tensors of shape SC(Z)YX, in the order produced by `forward`.

        Returns
        -------
        torch.Tensor
            Mean of the de-augmented tensors, in the original orientation.
        """
        yx_dims = (-2, -1)

        restored = [x[0]]
        # Undo k quarter turns with -k quarter turns.
        restored.extend(rot90(x[turns], -turns, dims=yx_dims) for turns in (1, 2, 3))
        # A flip is its own inverse.
        restored.extend(flip(x[4 + i], dims=(dim,)) for i, dim in enumerate(yx_dims))
        # Undo the flip first, then the single quarter turn.
        restored.extend(
            rot90(flip(x[6 + i], dims=(dim,)), -1, dims=yx_dims)
            for i, dim in enumerate(yx_dims)
        )

        return mean(stack(restored), dim=0)
@@ -0,0 +1,123 @@
1
+ """XY flip transform."""
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from careamics.transforms.transform import Transform
8
+
9
+
10
class XYFlip(Transform):
    """Randomly flip a patch along its Y or its X axis (never both at once).

    When the transform fires (probability `p`), one of the enabled axes is drawn
    uniformly at random and the patch (and the target, if given) is flipped along
    it. Patches are expected in C(Z)YX layout, so Y and X are the last two axes.

    Attributes
    ----------
    axis_indices : List[int]
        Axes that may be flipped, as negative indices (-2 for Y, -1 for X).
    rng : np.random.Generator
        Random number generator.
    p : float
        Probability of applying the transform.

    Parameters
    ----------
    flip_x : bool, optional
        Whether flipping along the X axis is allowed, by default True.
    flip_y : bool, optional
        Whether flipping along the Y axis is allowed, by default True.
    p : float, optional
        Probability of applying the transform, by default 0.5.
    seed : Optional[int], optional
        Seed of the random number generator, by default None.
    """

    def __init__(
        self,
        flip_x: bool = True,
        flip_y: bool = True,
        p: float = 0.5,
        seed: Optional[int] = None,
    ) -> None:
        """Validate the arguments and set up the random number generator.

        Parameters
        ----------
        flip_x : bool, optional
            Whether flipping along the X axis is allowed, by default True.
        flip_y : bool, optional
            Whether flipping along the Y axis is allowed, by default True.
        p : float
            Probability of applying the transform, by default 0.5.
        seed : Optional[int], optional
            Seed of the random number generator, by default None.

        Raises
        ------
        ValueError
            If `p` lies outside [0, 1], or if both `flip_x` and `flip_y` are False.
        """
        if p > 1 or p < 0:
            raise ValueError("Probability must be in [0, 1].")

        if not (flip_x or flip_y):
            raise ValueError("At least one axis must be flippable.")

        self.p = p

        # Candidate axes in a fixed order (Y first, then X): `rng.choice` picks
        # by position, so the order matters for reproducibility.
        candidates = ((-2, flip_y), (-1, flip_x))
        self.axis_indices = [axis for axis, allowed in candidates if allowed]

        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Randomly flip the patch and, if provided, the target along the same axis.

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : Optional[np.ndarray], optional
            Target for the patch, by default None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            Transformed patch and target; the inputs themselves when the
            transform does not fire.
        """
        # A single uniform draw decides whether the transform fires at all.
        if self.rng.random() > self.p:
            return patch, target

        flip_axis = self.rng.choice(self.axis_indices)

        flipped_target = None if target is None else self._apply(target, flip_axis)
        return self._apply(patch, flip_axis), flipped_target

    def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
        """Flip an array along one axis.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        axis : int
            Axis to flip.

        Returns
        -------
        np.ndarray
            Flipped, C-contiguous copy of the patch.
        """
        flipped_view = np.flip(patch, axis=axis)
        # np.flip returns a negative-stride view; materialize it contiguously.
        return np.ascontiguousarray(flipped_view)