careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (134) hide show
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,383 @@
1
+ """
2
+ Pixel manipulation methods.
3
+
4
+ Pixel manipulation is used in N2V and similar algorithms to replace the value of
5
+ masked pixels.
6
+ """
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+
11
+ from .struct_mask_parameters import StructMaskParameters
12
+
13
+
14
def _apply_struct_mask(
    patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
) -> np.ndarray:
    """Apply structN2V masks to a patch, in place.

    Around each point of `coords`, the pixels covered by a line-shaped mask of
    length `struct_params.span` (the point itself excluded) are replaced by values
    drawn uniformly between the minimum and the maximum of `patch`. The mask extends
    along the last axis (X) when `struct_params.axis` is 0 and along the
    second-to-last axis (Y) when it is 1, so for 3D patches it stays within a single
    Z plane. Mask pixels falling outside the patch along that axis are ignored.

    Parameters
    ----------
    patch : np.ndarray
        Patch to be manipulated, 2D or 3D. It is modified in place.
    coords : np.ndarray
        Coordinates of the mask centers, one row per center.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    np.ndarray
        The same `patch` array, with the structN2V mask applied.
    """
    n_dims = patch.ndim

    # Patch axis (counted from the end) along which the mask extends
    struct_axis = -1 - struct_params.axis

    # Line-shaped kernel: `span` pixels along `struct_axis`, singleton elsewhere
    kernel = np.ones(struct_params.span).reshape((1,) * (n_dims - 1) + (-1,))
    kernel = np.moveaxis(kernel, -1, struct_axis)

    # Switch off the kernel center (index span // 2 along the struct axis)
    kernel_center = np.array(kernel.shape) // 2
    kernel[tuple(kernel_center)] = 0

    # Offsets of the remaining kernel pixels w.r.t. the center, shape (ndim, n_off)
    offsets = np.argwhere(kernel == 1).T - kernel_center[:, None]

    # Every offset added to every center, offset-major order, one row per target
    targets = (offsets.T[:, None, :] + coords[None, :, :]).reshape(-1, n_dims)

    # Keep only the targets lying inside the patch along the struct axis
    along_axis = targets[:, struct_axis]
    inside = (along_axis >= 0) & (along_axis <= patch.shape[struct_axis] - 1)
    targets = targets[inside]

    # Replace the targeted pixels with random values from a flat distribution
    low, high = patch.min(), patch.max()
    patch[tuple(targets.T)] = np.random.uniform(low, high, size=targets.shape[0])

    return patch
72
+
73
+
74
def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
    """
    Randomly round a (possibly non-integer) grid step to an integer value.

    This accounts for grids whose step size is not an integer. A random bit is
    always drawn from `rng`. An integer `step` is returned unchanged; otherwise
    the bit decides whether `step` is rounded down (0) or up (1).

    Parameters
    ----------
    step : float
        Step size of the grid, output of np.linspace.
    rng : np.random.Generator
        Random number generator, advanced by exactly one draw.

    Returns
    -------
    np.ndarray
        Rounded step, as a NumPy floating point scalar.
    """
    coin = rng.integers(0, 2)

    if np.floor(step) != step and coin != 0:
        return np.ceil(step)
    return np.floor(step)
97
+
98
+
99
def _get_stratified_coords(
    mask_pixel_perc: float,
    shape: Union[Tuple[int, int], Tuple[int, int, int]],
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Generate coordinates of the pixels to mask.

    Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
    the distance between masked pixels is approximately the same. A regular grid
    with a spacing of roughly `(100 / mask_pixel_perc) ** (1 / len(shape))` pixels
    is laid over the patch, and every grid point is shifted by a random integer
    offset along each axis.

    Parameters
    ----------
    mask_pixel_perc : float
        Actual (quasi) percentage of masked pixels across the whole image. Used in
        calculating the distance between masked pixels across each axis. Must be
        strictly positive.
    shape : Tuple[int, ...]
        Shape of the input patch, 2D or 3D.
    rng : np.random.Generator, optional
        Random number generator, by default a new unseeded generator. Passing a
        seeded generator makes the output reproducible.

    Returns
    -------
    np.ndarray
        Array of coordinates of the masked pixels, shape (n_pixels, len(shape)).

    Raises
    ------
    ValueError
        If `shape` is not 2D or 3D, or if `mask_pixel_perc` is not strictly
        positive.
    """
    if len(shape) < 2 or len(shape) > 3:
        raise ValueError(
            "Calculating coordinates is only possible for 2D and 3D patches"
        )

    # A non-positive percentage would otherwise fail with a division by zero or
    # an obscure complex-to-int conversion further down.
    if mask_pixel_perc <= 0:
        raise ValueError(
            f"The percentage of masked pixels must be strictly positive, got "
            f"{mask_pixel_perc}."
        )

    if rng is None:
        rng = np.random.default_rng()

    # Distance (along each axis) between two grid points, chosen so that about
    # `mask_pixel_perc` percent of the pixels lie on the grid. Clamped to 1 so that
    # very large percentages do not lead to a division by zero.
    mask_pixel_distance = max(
        1, int(np.round((100 / mask_pixel_perc) ** (1 / len(shape))))
    )

    # Define a grid of coordinates for each axis in the input patch and the step size
    pixel_coords = []
    steps = []
    for axis_size in shape:
        # number of grid points needed to cover the axis with the given spacing
        num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
        # evenly spaced integer grid positions along the axis, and the actual
        # (possibly non-integer) spacing between them
        axis_pixel_coords, step = np.linspace(
            0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
        )
        pixel_coords.append(axis_pixel_coords)
        steps.append(step)

    # Cartesian product of the per-axis grid positions, one row per grid point
    coordinate_grid_list = np.meshgrid(*pixel_coords)
    coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T

    # Shift every grid point by a random offset in [0, jitter - 1] along each axis,
    # where the jitter is the (randomly rounded) largest grid spacing
    jitter = _odd_jitter_func(float(max(steps)), rng)
    grid_random_increment = rng.integers(
        jitter * np.ones_like(coordinate_grid).astype(np.int32) - 1,
        size=coordinate_grid.shape,
        endpoint=True,
    )
    coordinate_grid += grid_random_increment

    # Keep the shifted coordinates inside the patch
    coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1)
    return coordinate_grid
159
+
160
+
161
def _create_subpatch_center_mask(
    subpatch: np.ndarray, center_coords: np.ndarray
) -> np.ndarray:
    """Create a mask with the center of the subpatch masked.

    The mask selects the pixels usable as neighbours of the center pixel.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated.
    center_coords : np.ndarray
        Coordinates of the center pixel, relative to the subpatch.

    Returns
    -------
    np.ndarray
        Boolean array with the shape of `subpatch`, False at `center_coords` and
        True everywhere else.
    """
    # A plain boolean array is returned rather than `np.ma.make_mask`, which
    # shrinks an all-False mask (single-pixel subpatch) to the scalar `nomask`.
    mask = np.ones(subpatch.shape, dtype=bool)
    mask[tuple(center_coords)] = False
    return mask
181
+
182
+
183
def _create_subpatch_struct_mask(
    subpatch: np.ndarray, center_coords: np.ndarray, struct_params: StructMaskParameters
) -> np.ndarray:
    """Create a structN2V mask for the subpatch.

    The returned mask is True for the pixels that may be used to compute the
    replacement value of the center pixel. It is False for the center pixel and
    for its structN2V neighbours: the `span` pixels through the center along the
    structN2V axis, clipped to the subpatch.

    The axis convention is the same as in `_apply_struct_mask`: axis 0 masks along
    the last axis (X), axis 1 along the second-to-last axis (Y). The masked pixels
    are at offsets `-(span // 2)` to `span - 1 - span // 2` from the center, which
    is the extent of the kernel used there.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated.
    center_coords : np.ndarray
        Coordinates of the center pixel, relative to the subpatch.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    np.ndarray
        Boolean StructN2V mask with the shape of `subpatch`.
    """
    mask = np.ones(subpatch.shape, dtype=bool)
    center = [int(c) for c in center_coords]

    # Array axis (counted from the end) along which the structN2V mask extends
    moving_axis = -1 - struct_params.axis
    axis_center = center[moving_axis]

    # Extent of the mask along the moving axis, clipped to the subpatch
    start = max(0, axis_center - struct_params.span // 2)
    stop = min(
        axis_center - struct_params.span // 2 + struct_params.span,
        subpatch.shape[moving_axis],
    )

    # Mask the line through the center along the moving axis
    line_index: list = list(center)
    line_index[moving_axis] = slice(start, stop)
    mask[tuple(line_index)] = False

    # The center pixel must never contribute to its own replacement value, even
    # for degenerate spans
    mask[tuple(center)] = False

    return mask
222
+
223
+
224
def uniform_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: Optional[StructMaskParameters] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the value of a neighboring pixel.

    The pixels to manipulate are selected on a randomly jittered grid, so that they
    have an approximately uniform probability to be selected across the whole
    patch. Each of them is replaced by the value of a pixel drawn from the subpatch
    centered on it. The offsets are drawn uniformly and independently along each
    axis, and the resulting coordinates are clipped to the patch.

    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to prevent a pixel from being replaced by itself, by default True.
        The zero offset is then removed along every axis independently, so the
        sampled offsets are non-zero along every axis (before clipping).
    struct_params: Optional[StructMaskParameters]
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Tuple containing the manipulated patch and the corresponding mask. The mask
        is uint8, 1 where the replacement changed the pixel value and 0 elsewhere.
        It is computed before the structN2V mask is applied.
    """
    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced, one row per pixel
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
    rng = np.random.default_rng()

    # Candidate offsets along each axis, e.g. [-5, ..., 5] for a subpatch size of 11
    roi_span_full = np.arange(
        -np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)
    ).astype(np.int32)

    # Remove the zero offset if a pixel must not be replaced by itself
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full

    # Randomly select an offset along each axis for every pixel
    random_increment = rng.choice(roi_span, size=subpatch_centers.shape)

    # Clip the coordinates to the patch size
    upper_bounds = np.array(patch.shape) - 1
    replacement_coords = np.clip(subpatch_centers + random_increment, 0, upper_bounds)

    if remove_center:
        # At the patch corners every offset can point outside the patch, and
        # clipping then maps the replacement back onto the pixel itself. Mirror
        # the offsets of those pixels so that the center is never used.
        is_center = np.all(replacement_coords == subpatch_centers, axis=1)
        replacement_coords[is_center] = np.clip(
            subpatch_centers[is_center] - random_increment[is_center],
            0,
            upper_bounds,
        )

    # Get the replacement pixels from all subpatches
    replacement_pixels = patch[tuple(replacement_coords.T.tolist())]

    # Replace the original pixels with the replacement pixels
    transformed_patch[tuple(subpatch_centers.T.tolist())] = replacement_pixels
    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params
        )

    return (
        transformed_patch,
        mask,
    )
299
+
300
+
301
def median_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: Optional[StructMaskParameters] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version: manipulated pixels are selected randomly away from a grid with an
    approximate uniform probability to be selected across the whole patch. The
    median is computed over the subpatch centered on each manipulated pixel,
    clipped to the patch. Without `struct_params`, only the manipulated pixel itself
    is excluded from the median. With `struct_params`, the pixels used are those
    selected by `_create_subpatch_struct_mask`.

    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed from, by default 11.
    struct_params: Optional[StructMaskParameters]
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Tuple containing the manipulated patch and the corresponding mask. The mask
        is uint8, 1 where the median replacement changed the pixel value and 0
        elsewhere. It is computed before the structN2V mask is applied.
    """
    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced, one row per pixel
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)

    # Offsets of the subpatch (start, stop) boundaries w.r.t. its center
    roi_span = np.array(
        [-np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)]
    ).astype(np.int32)

    # Subpatch boundaries, dimensions: n dims, n centers, (start, stop)
    subpatch_crops_span_full = subpatch_centers[np.newaxis, ...].T + roi_span

    # Clip the boundaries to the patch
    subpatch_crops_span_clipped = np.clip(
        subpatch_crops_span_full,
        a_min=np.zeros_like(patch.shape)[:, np.newaxis, np.newaxis],
        a_max=np.array(patch.shape)[:, np.newaxis, np.newaxis],
    )

    for idx, center in enumerate(subpatch_centers):
        subpatch_coords = subpatch_crops_span_clipped[:, idx, ...]
        idxs = [
            slice(x[0], x[1]) if x[1] - x[0] > 0 else slice(0, 1)
            for x in subpatch_coords
        ]
        subpatch = patch[tuple(idxs)]

        # Position of the center within the (possibly clipped) subpatch
        subpatch_center_adjusted = center - subpatch_coords[:, 0]

        if struct_params is None:
            subpatch_mask = _create_subpatch_center_mask(
                subpatch, subpatch_center_adjusted
            )
        else:
            subpatch_mask = _create_subpatch_struct_mask(
                subpatch, subpatch_center_adjusted, struct_params
            )

        neighbors = subpatch[subpatch_mask]

        # With tiny subpatches (e.g. subpatch_size=1) no pixel is left once the
        # masked ones are excluded, and the median of an empty array is NaN.
        # Leave the pixel untouched rather than writing NaN into the patch.
        if neighbors.size == 0:
            continue

        transformed_patch[tuple(center)] = np.median(neighbors)

    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params
        )

    return (
        transformed_patch,
        mask,
    )
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+
4
+
5
@dataclass
class StructMaskParameters:
    """Settings describing a structN2V mask.

    Parameters
    ----------
    axis : Literal[0, 1]
        Orientation of the mask: horizontal (0) or vertical (1).
    span : int
        Number of pixels covered by the mask along its axis.
    """

    # Orientation of the mask, 0 = horizontal, 1 = vertical
    axis: Literal[0, 1]
    # Length of the mask, in pixels
    span: int
@@ -0,0 +1,74 @@
1
+ """Test-time augmentations."""
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ from torch import Tensor, flip, mean, rot90, stack
6
+
7
+
8
+ # TODO add tests
9
class ImageRestorationTTA:
    """
    Test-time augmentation for image restoration tasks.

    The augmentation is performed using all 90 deg rotations in the YX plane and
    their flipped versions (the original image flipped included).

    Tensors should be of shape SC(Z)YX. Flips are only applied to spatial axes:
    X for SCYX tensors, Z and X for SCZYX tensors. The channel axis is never
    flipped.

    This transformation is used in the LightningModule in order to perform test-time
    augmentation.
    """

    def __init__(self) -> None:
        """Constructor."""
        pass

    @staticmethod
    def _flip_dims(x: Tensor) -> List[int]:
        """Return the spatial axes to flip for a SC(Z)YX tensor.

        Parameters
        ----------
        x : Tensor
            Tensor of shape SCYX or SCZYX.

        Returns
        -------
        List[int]
            [-3, -1] (Z and X) for 5D tensors, [-1] (X) otherwise.
        """
        # For SCYX tensors, axis -3 is the channel axis and must not be flipped.
        return [-3, -1] if x.ndim == 5 else [-1]

    def forward(self, x: Tensor) -> List[Tensor]:
        """
        Apply test-time augmentation to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor, shape SC(Z)YX.

        Returns
        -------
        List[Tensor]
            List of 8 augmented tensors: the 4 rotations (0, 90, 180, 270 deg),
            followed by the flipped version of each of them.
        """
        dims = self._flip_dims(x)
        rotated = [x] + [rot90(x, k, dims=(-2, -1)) for k in (1, 2, 3)]
        return rotated + [flip(r, dims=dims) for r in rotated]

    def backward(self, x: List[Tensor]) -> Tensor:
        """Undo the test-time augmentation.

        Parameters
        ----------
        x : List[Tensor]
            List of augmented tensors, in the order produced by `forward`.

        Returns
        -------
        Tensor
            Mean of the de-augmented tensors.
        """
        dims = self._flip_dims(x[0])
        reverse = [
            x[0],
            rot90(x[1], -1, dims=(-2, -1)),
            rot90(x[2], -2, dims=(-2, -1)),
            rot90(x[3], -3, dims=(-2, -1)),
            flip(x[4], dims=dims),
            rot90(flip(x[5], dims=dims), -1, dims=(-2, -1)),
            rot90(flip(x[6], dims=dims), -2, dims=(-2, -1)),
            rot90(flip(x[7], dims=dims), -3, dims=(-2, -1)),
        ]
        return mean(stack(reverse), dim=0)
@@ -0,0 +1,95 @@
1
+ from typing import Any, Dict, Tuple
2
+
3
+ import numpy as np
4
+ from albumentations import DualTransform
5
+
6
+
7
class XYRandomRotate90(DualTransform):
    """Applies random 90 degree rotations to the YX axis.

    This transform expects (Z)YXC dimensions. The number of quarter turns (1, 2 or
    3) is sampled in `get_params`.

    Parameters
    ----------
    p : float, optional
        Probability to apply the transform, by default 0.5
    is_3D : bool, optional
        Whether the patches are 3D, by default False
    """

    def __init__(self, p: float = 0.5, is_3D: bool = False):
        """Constructor.

        Parameters
        ----------
        p : float, optional
            Probability to apply the transform, by default 0.5
        is_3D : bool, optional
            Whether the patches are 3D, by default False
        """
        super().__init__(p=p)

        self.is_3D = is_3D

        # rotation axes: the Y and X axes of a ZYXC (3D) or YXC (2D) array
        if is_3D:
            self.axes = (1, 2)
        else:
            self.axes = (0, 1)

    def get_params(self, **kwargs: Any) -> Dict[str, int]:
        """Get the transform parameters.

        Returns
        -------
        Dict[str, int]
            Transform parameters, the number of 90 degree rotations (1, 2 or 3).
        """
        return {"n_rotations": np.random.randint(1, 4)}

    def apply(self, patch: np.ndarray, n_rotations: int, **kwargs: Any) -> np.ndarray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        n_rotations : int
            Number of 90 degree rotations to apply in the YX plane.

        Returns
        -------
        np.ndarray
            Rotated patch, as a contiguous array.

        Raises
        ------
        ValueError
            If the transform is 3D and the patch is not 4D (ZYXC).
        """
        # Same validation as `apply_to_mask`: a 3D transform requires ZYXC arrays,
        # otherwise `self.axes` would point at the wrong (or non-existent) axes.
        if len(patch.shape) != 4 and self.is_3D:
            raise ValueError(
                "Incompatible patch shape and dimensionality. ZYXC patch shape "
                "expected, but got YXC shape."
            )

        return np.ascontiguousarray(np.rot90(patch, k=n_rotations, axes=self.axes))

    def apply_to_mask(
        self, mask: np.ndarray, n_rotations: int, **kwargs: Any
    ) -> np.ndarray:
        """Apply the transform to the mask.

        Parameters
        ----------
        mask : np.ndarray
            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        n_rotations : int
            Number of 90 degree rotations to apply in the YX plane.

        Returns
        -------
        np.ndarray
            Rotated mask, as a contiguous array.

        Raises
        ------
        ValueError
            If the transform is 3D and the mask is not 4D (ZYXC).
        """
        if len(mask.shape) != 4 and self.is_3D:
            raise ValueError(
                "Incompatible mask shape and dimensionality. ZYXC patch shape "
                "expected, but got YXC shape."
            )

        return np.ascontiguousarray(np.rot90(mask, k=n_rotations, axes=self.axes))

    def get_transform_init_args_names(self) -> Tuple[str, str]:
        """
        Get the transform arguments.

        Returns
        -------
        Tuple[str, str]
            Names of the constructor arguments.
        """
        return ("p", "is_3D")
@@ -2,19 +2,17 @@
2
2
 
3
3
 
4
4
  __all__ = [
5
- "denormalize",
6
- "normalize",
7
- "get_device",
8
- "check_axes_validity",
9
- "add_axes",
10
- "check_tiling_validity",
11
5
  "cwd",
12
- "MetricTracker",
6
+ "get_ram_size",
7
+ "check_path_exists",
8
+ "BaseEnum",
9
+ "get_logger",
10
+ "get_careamics_home",
13
11
  ]
14
12
 
15
13
 
16
- from .context import cwd
17
- from .metrics import MetricTracker
18
- from .normalization import denormalize, normalize
19
- from .torch_utils import get_device
20
- from .validators import add_axes, check_axes_validity, check_tiling_validity
14
+ from .base_enum import BaseEnum
15
+ from .context import cwd, get_careamics_home
16
+ from .logging import get_logger
17
+ from .path_utils import check_path_exists
18
+ from .ram import get_ram_size
@@ -0,0 +1,32 @@
1
+ from enum import Enum, EnumMeta
2
+ from typing import Any
3
+
4
+
5
class _ContainerEnum(EnumMeta):
    """Enum metaclass supporting value membership tests.

    `item in MyEnum` returns True when `MyEnum(item)` succeeds, i.e. when `item`
    is a member or the value of a member.
    """

    def __contains__(cls, item: Any) -> bool:
        """Check whether `item` is a member or a member value of the enum.

        Parameters
        ----------
        item : Any
            Member or value to look up.

        Returns
        -------
        bool
            Whether `cls(item)` succeeds.
        """
        try:
            cls(item)
        except ValueError:
            return False
        return True

    # Plain method on the metaclass (not a classmethod): when called as
    # `MyEnum.has_value(...)`, `cls` is then the enum class itself, which owns
    # `_value2member_map_`, instead of the metaclass.
    def has_value(cls, value: Any) -> bool:
        """Check whether `value` is the value of one of the enum members.

        Parameters
        ----------
        value : Any
            Value to look up.

        Returns
        -------
        bool
            Whether `value` is a key of the enum's value-to-member map.
        """
        return value in cls._value2member_map_
16
+
17
+
18
class BaseEnum(Enum, metaclass=_ContainerEnum):
    """Base Enum class supporting `value in MyEnum` membership checks.

    Subclasses can test whether a raw value corresponds to one of their members
    with the `in` operator.

    Example
    -------
    >>> from careamics.utils.base_enum import BaseEnum
    >>> # Define a new enum
    >>> class BaseEnumExtension(BaseEnum):
    ...     VALUE = "value"
    >>> # Check if value is in the enum
    >>> "value" in BaseEnumExtension
    True
    >>> "other" in BaseEnumExtension
    False
    """
@@ -9,6 +9,24 @@ from pathlib import Path
9
9
  from typing import Iterator, Union
10
10
 
11
11
 
12
def get_careamics_home() -> Path:
    """Return the CAREamics home directory, creating it if it does not exist.

    The CAREamics home directory is the hidden `.careamics` folder located in the
    user's home directory.

    Returns
    -------
    Path
        CAREamics home directory path.
    """
    careamics_home = Path.home().joinpath(".careamics")

    if careamics_home.exists():
        return careamics_home

    careamics_home.mkdir(parents=True, exist_ok=True)
    return careamics_home
28
+
29
+
12
30
  @contextmanager
13
31
  def cwd(path: Union[str, Path]) -> Iterator[None]:
14
32
  """
@@ -29,8 +47,10 @@ def cwd(path: Union[str, Path]) -> Iterator[None]:
29
47
 
30
48
  Examples
31
49
  --------
32
- >>> with cwd(path):
33
- ... pass
50
+ The working directory is changed within the block and then restored to the original one.
51
+
52
+ >>> with cwd(my_path):
53
+ ... pass # do something
34
54
  """
35
55
  path = Path(path)
36
56