careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,34 @@
1
1
  """Normalization and denormalization transforms for image patches."""
2
2
 
3
- from typing import Optional, Tuple
3
+ from typing import Optional
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.transforms.transform import Transform
8
9
 
9
10
 
11
+ def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
12
+ """Reshape stats to match the number of dimensions of the input image.
13
+
14
+ This allows to broadcast the stats (mean or std) to the image dimensions, and
15
+ thus directly perform a vectorial calculation.
16
+
17
+ Parameters
18
+ ----------
19
+ stats : list of float
20
+ List of stats, mean or standard deviation.
21
+ ndim : int
22
+ Number of dimensions of the image, including the C channel.
23
+
24
+ Returns
25
+ -------
26
+ NDArray
27
+ Reshaped stats.
28
+ """
29
+ return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
30
+
31
+
10
32
  class Normalize(Transform):
11
33
  """
12
34
  Normalize an image or image patch.
@@ -19,154 +41,203 @@ class Normalize(Transform):
19
41
 
20
42
  Parameters
21
43
  ----------
22
- mean : float
23
- Mean value.
24
- std : float
25
- Standard deviation value.
44
+ image_means : list of float
45
+ Mean value per channel.
46
+ image_stds : list of float
47
+ Standard deviation value per channel.
48
+ target_means : list of float, optional
49
+ Target mean value per channel, by default None.
50
+ target_stds : list of float, optional
51
+ Target standard deviation value per channel, by default None.
26
52
 
27
53
  Attributes
28
54
  ----------
29
- mean : float
30
- Mean value.
31
- std : float
32
- Standard deviation value.
55
+ image_means : list of float
56
+ Mean value per channel.
57
+ image_stds : list of float
58
+ Standard deviation value per channel.
59
+ target_means :list of float, optional
60
+ Target mean value per channel, by default None.
61
+ target_stds : list of float, optional
62
+ Target standard deviation value per channel, by default None.
33
63
  """
34
64
 
35
65
  def __init__(
36
66
  self,
37
- mean: float,
38
- std: float,
67
+ image_means: list[float],
68
+ image_stds: list[float],
69
+ target_means: Optional[list[float]] = None,
70
+ target_stds: Optional[list[float]] = None,
39
71
  ):
40
72
  """Constructor.
41
73
 
42
74
  Parameters
43
75
  ----------
44
- mean : float
45
- Mean value.
46
- std : float
47
- Standard deviation value.
76
+ image_means : list of float
77
+ Mean value per channel.
78
+ image_stds : list of float
79
+ Standard deviation value per channel.
80
+ target_means : list of float, optional
81
+ Target mean value per channel, by default None.
82
+ target_stds : list of float, optional
83
+ Target standard deviation value per channel, by default None.
48
84
  """
49
- self.mean = mean
50
- self.std = std
85
+ self.image_means = image_means
86
+ self.image_stds = image_stds
87
+ self.target_means = target_means
88
+ self.target_stds = target_stds
89
+
51
90
  self.eps = 1e-6
52
91
 
53
92
  def __call__(
54
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
55
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
93
+ self, patch: np.ndarray, target: Optional[NDArray] = None
94
+ ) -> tuple[NDArray, Optional[NDArray]]:
56
95
  """Apply the transform to the source patch and the target (optional).
57
96
 
58
97
  Parameters
59
98
  ----------
60
- patch : np.ndarray
99
+ patch : NDArray
61
100
  Patch, 2D or 3D, shape C(Z)YX.
62
- target : Optional[np.ndarray], optional
101
+ target : NDArray, optional
63
102
  Target for the patch, by default None.
64
103
 
65
104
  Returns
66
105
  -------
67
- Tuple[np.ndarray, Optional[np.ndarray]]
68
- Transformed patch and target.
106
+ tuple of NDArray
107
+ Transformed patch and target, the target can be returned as `None`.
69
108
  """
70
- norm_patch = self._apply(patch)
71
- norm_target = self._apply(target) if target is not None else None
109
+ if len(self.image_means) != patch.shape[0]:
110
+ raise ValueError(
111
+ f"Number of means (got a list of size {len(self.image_means)}) and "
112
+ f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
113
+ )
114
+
115
+ # reshape mean and std and apply the normalization to the patch
116
+ means = _reshape_stats(self.image_means, patch.ndim)
117
+ stds = _reshape_stats(self.image_stds, patch.ndim)
118
+ norm_patch = self._apply(patch, means, stds)
119
+
120
+ # same for the target patch
121
+ if (
122
+ target is not None
123
+ and self.target_means is not None
124
+ and self.target_stds is not None
125
+ ):
126
+ target_means = _reshape_stats(self.target_means, target.ndim)
127
+ target_stds = _reshape_stats(self.target_stds, target.ndim)
128
+ norm_target = self._apply(target, target_means, target_stds)
129
+ else:
130
+ norm_target = None
72
131
 
73
132
  return norm_patch, norm_target
74
133
 
75
- def _apply(self, patch: np.ndarray) -> np.ndarray:
134
+ def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
76
135
  """
77
136
  Apply the transform to the image.
78
137
 
79
138
  Parameters
80
139
  ----------
81
- patch : np.ndarray
140
+ patch : NDArray
82
141
  Image patch, 2D or 3D, shape C(Z)YX.
142
+ mean : NDArray
143
+ Mean values.
144
+ std : NDArray
145
+ Standard deviations.
83
146
 
84
147
  Returns
85
148
  -------
86
- np.ndarray
87
- Normalizedimage patch.
149
+ NDArray
150
+ Normalized image patch.
88
151
  """
89
- return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
152
+ return ((patch - mean) / (std + self.eps)).astype(np.float32)
90
153
 
91
154
 
92
155
  class Denormalize:
93
156
  """
94
- Denormalize an image or image patch.
157
+ Denormalize an image.
95
158
 
96
159
  Denormalization is performed expecting a zero mean and unit variance input. This
97
160
  transform expects C(Z)YX dimensions.
98
161
 
99
- Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
162
+ Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
100
163
  division by zero during the normalization step, which is taken into account during
101
164
  denormalization.
102
165
 
103
166
  Parameters
104
167
  ----------
105
- mean : float
106
- Mean value.
107
- std : float
108
- Standard deviation value.
168
+ image_means : list or tuple of float
169
+ Mean value per channel.
170
+ image_stds : list or tuple of float
171
+ Standard deviation value per channel.
109
172
 
110
- Attributes
111
- ----------
112
- mean : float
113
- Mean value.
114
- std : float
115
- Standard deviation value.
116
173
  """
117
174
 
118
175
  def __init__(
119
176
  self,
120
- mean: float,
121
- std: float,
177
+ image_means: list[float],
178
+ image_stds: list[float],
122
179
  ):
123
180
  """Constructor.
124
181
 
125
182
  Parameters
126
183
  ----------
127
- mean : float
128
- Mean.
129
- std : float
130
- Standard deviation.
184
+ image_means : list of float
185
+ Mean value per channel.
186
+ image_stds : list of float
187
+ Standard deviation value per channel.
131
188
  """
132
- self.mean = mean
133
- self.std = std
189
+ self.image_means = image_means
190
+ self.image_stds = image_stds
191
+
134
192
  self.eps = 1e-6
135
193
 
136
- def __call__(
137
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
138
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
139
- """Apply the transform to the source patch and the target (optional).
194
+ def __call__(self, patch: NDArray) -> NDArray:
195
+ """Reverse the normalization operation for a batch of patches.
140
196
 
141
197
  Parameters
142
198
  ----------
143
- patch : np.ndarray
144
- Patch, 2D or 3D, shape C(Z)YX.
145
- target : Optional[np.ndarray], optional
146
- Target for the patch, by default None.
199
+ patch : NDArray
200
+ Patch, 2D or 3D, shape BC(Z)YX.
147
201
 
148
202
  Returns
149
203
  -------
150
- Tuple[np.ndarray, Optional[np.ndarray]]
151
- Transformed patch and target.
204
+ NDArray
205
+ Transformed array.
152
206
  """
153
- norm_patch = self._apply(patch)
154
- norm_target = self._apply(target) if target is not None else None
207
+ if len(self.image_means) != patch.shape[1]:
208
+ raise ValueError(
209
+ f"Number of means (got a list of size {len(self.image_means)}) and "
210
+ f"number of channels (got shape {patch.shape} for BC(Z)YX) do not "
211
+ f"match."
212
+ )
155
213
 
156
- return norm_patch, norm_target
214
+ means = _reshape_stats(self.image_means, patch.ndim)
215
+ stds = _reshape_stats(self.image_stds, patch.ndim)
216
+
217
+ denorm_array = self._apply(
218
+ patch,
219
+ np.swapaxes(means, 0, 1), # swap axes as C channel is axis 1
220
+ np.swapaxes(stds, 0, 1),
221
+ )
222
+
223
+ return denorm_array.astype(np.float32)
157
224
 
158
- def _apply(self, patch: np.ndarray) -> np.ndarray:
225
+ def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
159
226
  """
160
227
  Apply the transform to the image.
161
228
 
162
229
  Parameters
163
230
  ----------
164
- patch : np.ndarray
231
+ array : NDArray
165
232
  Image patch, 2D or 3D, shape C(Z)YX.
233
+ mean : NDArray
234
+ Mean values.
235
+ std : NDArray
236
+ Standard deviations.
166
237
 
167
238
  Returns
168
239
  -------
169
- np.ndarray
170
- Denormalized image patch.
240
+ NDArray
241
+ Denormalized image array.
171
242
  """
172
- return patch * (self.std + self.eps) + self.mean
243
+ return array * (std + self.eps) + mean
@@ -13,7 +13,10 @@ from .struct_mask_parameters import StructMaskParameters
13
13
 
14
14
 
15
15
  def _apply_struct_mask(
16
- patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
16
+ patch: np.ndarray,
17
+ coords: np.ndarray,
18
+ struct_params: StructMaskParameters,
19
+ rng: Optional[np.random.Generator] = None,
17
20
  ) -> np.ndarray:
18
21
  """Apply structN2V masks to patch.
19
22
 
@@ -31,12 +34,17 @@ def _apply_struct_mask(
31
34
  Coordinates of the ROI(subpatch) centers.
32
35
  struct_params : StructMaskParameters
33
36
  Parameters for the structN2V mask (axis and span).
37
+ rng : np.random.Generator or None
38
+ Random number generator.
34
39
 
35
40
  Returns
36
41
  -------
37
42
  np.ndarray
38
43
  Patch with the structN2V mask applied.
39
44
  """
45
+ if rng is None:
46
+ rng = np.random.default_rng()
47
+
40
48
  # relative axis
41
49
  moving_axis = -1 - struct_params.axis
42
50
 
@@ -67,7 +75,7 @@ def _apply_struct_mask(
67
75
  mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0)
68
76
 
69
77
  # replace neighbouring pixels with random values from flat dist
70
- patch[tuple(mix.T)] = np.random.uniform(patch.min(), patch.max(), size=mix.shape[0])
78
+ patch[tuple(mix.T)] = rng.uniform(patch.min(), patch.max(), size=mix.shape[0])
71
79
 
72
80
  return patch
73
81
 
@@ -98,7 +106,9 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
98
106
 
99
107
 
100
108
  def _get_stratified_coords(
101
- mask_pixel_perc: float, shape: Tuple[int, ...]
109
+ mask_pixel_perc: float,
110
+ shape: Tuple[int, ...],
111
+ rng: Optional[np.random.Generator] = None,
102
112
  ) -> np.ndarray:
103
113
  """
104
114
  Generate coordinates of the pixels to mask.
@@ -113,6 +123,8 @@ def _get_stratified_coords(
113
123
  calculating the distance between masked pixels across each axis.
114
124
  shape : Tuple[int, ...]
115
125
  Shape of the input patch.
126
+ rng : np.random.Generator or None
127
+ Random number generator.
116
128
 
117
129
  Returns
118
130
  -------
@@ -124,7 +136,8 @@ def _get_stratified_coords(
124
136
  "Calculating coordinates is only possible for 2D and 3D patches"
125
137
  )
126
138
 
127
- rng = np.random.default_rng()
139
+ if rng is None:
140
+ rng = np.random.default_rng()
128
141
 
129
142
  mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
130
143
  np.int32
@@ -228,6 +241,7 @@ def uniform_manipulate(
228
241
  subpatch_size: int = 11,
229
242
  remove_center: bool = True,
230
243
  struct_params: Optional[StructMaskParameters] = None,
244
+ rng: Optional[np.random.Generator] = None,
231
245
  ) -> Tuple[np.ndarray, np.ndarray]:
232
246
  """
233
247
  Manipulate pixels by replacing them with a neighbor values.
@@ -248,19 +262,23 @@ def uniform_manipulate(
248
262
  Size of the subpatch the new pixel value is sampled from, by default 11.
249
263
  remove_center : bool
250
264
  Whether to remove the center pixel from the subpatch, by default False.
251
- struct_params : Optional[StructMaskParameters]
265
+ struct_params : StructMaskParameters or None
252
266
  Parameters for the structN2V mask (axis and span).
267
+ rng : np.random.Generator or None
268
+ Random number generator.
253
269
 
254
270
  Returns
255
271
  -------
256
272
  Tuple[np.ndarray]
257
273
  Tuple containing the manipulated patch and the corresponding mask.
258
274
  """
275
+ if rng is None:
276
+ rng = np.random.default_rng()
277
+
259
278
  # Get the coordinates of the pixels to be replaced
260
279
  transformed_patch = patch.copy()
261
280
 
262
- subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
263
- rng = np.random.default_rng()
281
+ subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
264
282
 
265
283
  # Generate coordinate grid for subpatch
266
284
  roi_span_full = np.arange(
@@ -303,6 +321,7 @@ def median_manipulate(
303
321
  mask_pixel_percentage: float,
304
322
  subpatch_size: int = 11,
305
323
  struct_params: Optional[StructMaskParameters] = None,
324
+ rng: Optional[np.random.Generator] = None,
306
325
  ) -> Tuple[np.ndarray, np.ndarray]:
307
326
  """
308
327
  Manipulate pixels by replacing them with the median of their surrounding subpatch.
@@ -322,18 +341,23 @@ def median_manipulate(
322
341
  Approximate percentage of pixels to be masked.
323
342
  subpatch_size : int
324
343
  Size of the subpatch the new pixel value is sampled from, by default 11.
325
- struct_params : Optional[StructMaskParameters]
344
+ struct_params : StructMaskParameters or None, optional
326
345
  Parameters for the structN2V mask (axis and span).
346
+ rng : np.random.Generator or None, optional
347
+ Random number generato, by default None.
327
348
 
328
349
  Returns
329
350
  -------
330
351
  Tuple[np.ndarray]
331
352
  Tuple containing the manipulated patch, the original patch and the mask.
332
353
  """
354
+ if rng is None:
355
+ rng = np.random.default_rng()
356
+
333
357
  transformed_patch = patch.copy()
334
358
 
335
359
  # Get the coordinates of the pixels to be replaced
336
- subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
360
+ subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
337
361
 
338
362
  # Generate coordinate grid for subpatch
339
363
  roi_span = np.array(
@@ -1,11 +1,8 @@
1
1
  """Test-time augmentations."""
2
2
 
3
- from typing import List
4
-
5
3
  from torch import Tensor, flip, mean, rot90, stack
6
4
 
7
5
 
8
- # TODO add tests
9
6
  class ImageRestorationTTA:
10
7
  """
11
8
  Test-time augmentation for image restoration tasks.
@@ -13,62 +10,79 @@ class ImageRestorationTTA:
13
10
  The augmentation is performed using all 90 deg rotations and their flipped version,
14
11
  as well as the original image flipped.
15
12
 
16
- Tensors should be of shape SC(Z)YX
13
+ Tensors should be of shape SC(Z)YX.
17
14
 
18
15
  This transformation is used in the LightningModule in order to perform test-time
19
- agumentation.
16
+ augmentation.
20
17
  """
21
18
 
22
- def __init__(self) -> None:
23
- """Constructor."""
24
- pass
25
-
26
- def forward(self, x: Tensor) -> List[Tensor]:
19
+ def forward(self, input_tensor: Tensor) -> list[Tensor]:
27
20
  """
28
21
  Apply test-time augmentation to the input tensor.
29
22
 
30
23
  Parameters
31
24
  ----------
32
- x : Tensor
25
+ input_tensor : Tensor
33
26
  Input tensor, shape SC(Z)YX.
34
27
 
35
28
  Returns
36
29
  -------
37
- List[Tensor]
30
+ list of torch.Tensor
38
31
  List of augmented tensors.
39
32
  """
33
+ # axes: only applies to YX axes
34
+ axes = (-2, -1)
35
+
40
36
  augmented = [
41
- x,
42
- rot90(x, 1, dims=(-2, -1)),
43
- rot90(x, 2, dims=(-2, -1)),
44
- rot90(x, 3, dims=(-2, -1)),
37
+ # original
38
+ input_tensor,
39
+ # rotations
40
+ rot90(input_tensor, 1, dims=axes),
41
+ rot90(input_tensor, 2, dims=axes),
42
+ rot90(input_tensor, 3, dims=axes),
43
+ # original flipped
44
+ flip(input_tensor, dims=(axes[0],)),
45
+ flip(input_tensor, dims=(axes[1],)),
45
46
  ]
46
- augmented_flip = augmented.copy()
47
- for x_ in augmented:
48
- augmented_flip.append(flip(x_, dims=(-3, -1)))
49
- return augmented_flip
50
47
 
51
- def backward(self, x: List[Tensor]) -> Tensor:
48
+ # rotated once, flipped
49
+ augmented.extend(
50
+ [
51
+ flip(augmented[1], dims=(axes[0],)),
52
+ flip(augmented[1], dims=(axes[1],)),
53
+ ]
54
+ )
55
+
56
+ return augmented
57
+
58
+ def backward(self, x: list[Tensor]) -> Tensor:
52
59
  """Undo the test-time augmentation.
53
60
 
54
61
  Parameters
55
62
  ----------
56
63
  x : Any
57
- List of augmented tensors.
64
+ List of augmented tensors of shape SC(Z)YX.
58
65
 
59
66
  Returns
60
67
  -------
61
68
  Any
62
69
  Original tensor.
63
70
  """
71
+ axes = (-2, -1)
72
+
64
73
  reverse = [
74
+ # original
65
75
  x[0],
66
- rot90(x[1], -1, dims=(-2, -1)),
67
- rot90(x[2], -2, dims=(-2, -1)),
68
- rot90(x[3], -3, dims=(-2, -1)),
69
- flip(x[4], dims=(-3, -1)),
70
- rot90(flip(x[5], dims=(-3, -1)), -1, dims=(-2, -1)),
71
- rot90(flip(x[6], dims=(-3, -1)), -2, dims=(-2, -1)),
72
- rot90(flip(x[7], dims=(-3, -1)), -3, dims=(-2, -1)),
76
+ # rotated
77
+ rot90(x[1], -1, dims=axes),
78
+ rot90(x[2], -2, dims=axes),
79
+ rot90(x[3], -3, dims=axes),
80
+ # original flipped
81
+ flip(x[4], dims=(axes[0],)),
82
+ flip(x[5], dims=(axes[1],)),
83
+ # rotated once, flipped
84
+ rot90(flip(x[6], dims=(axes[0],)), -1, dims=axes),
85
+ rot90(flip(x[7], dims=(axes[1],)), -1, dims=axes),
73
86
  ]
87
+
74
88
  return mean(stack(reverse), dim=0)
@@ -7,9 +7,11 @@ __all__ = [
7
7
  "BaseEnum",
8
8
  "get_logger",
9
9
  "get_careamics_home",
10
+ "autocorrelation",
10
11
  ]
11
12
 
12
13
 
14
+ from .autocorrelation import autocorrelation
13
15
  from .base_enum import BaseEnum
14
16
  from .context import cwd, get_careamics_home
15
17
  from .logging import get_logger
@@ -0,0 +1,40 @@
1
+ """Autocorrelation function."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+
7
+ def autocorrelation(image: NDArray) -> NDArray:
8
+ """Compute the autocorrelation of an image.
9
+
10
+ This method is used to explore spatial correlations in images,
11
+ in particular in the noise.
12
+
13
+ The autocorrelation is normalized to the zero-shift value, which is centered in
14
+ the resulting images.
15
+
16
+ Parameters
17
+ ----------
18
+ image : NDArray
19
+ Input image.
20
+
21
+ Returns
22
+ -------
23
+ numpy.ndarray
24
+ Autocorrelation of the input image.
25
+ """
26
+ # normalize image
27
+ image = (image - np.mean(image)) / np.std(image)
28
+
29
+ # compute autocorrelation in fourier space
30
+ image = np.fft.fftn(image)
31
+ image = np.abs(image) ** 2
32
+ image = np.fft.ifftn(image).real
33
+
34
+ # normalize to zero shift value
35
+ image = image / image.flat[0]
36
+
37
+ # shift zero frequency to center
38
+ image = np.fft.fftshift(image)
39
+
40
+ return image
careamics/utils/ram.py CHANGED
@@ -5,11 +5,11 @@ import psutil
5
5
 
6
6
  def get_ram_size() -> int:
7
7
  """
8
- Get RAM size in bytes.
8
+ Get RAM size in mbytes.
9
9
 
10
10
  Returns
11
11
  -------
12
12
  int
13
13
  RAM size in mbytes.
14
14
  """
15
- return psutil.virtual_memory().total / 1024**2
15
+ return psutil.virtual_memory().available / 1024**2
@@ -1,31 +1,32 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: careamics
3
- Version: 0.1.0rc6
3
+ Version: 0.1.0rc8
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
7
- Author-email: Igor Zubarev <igor.zubarev@fht.org>, Joran Deschamps <joran.deschamps@fht.org>
7
+ Author-email: Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
8
8
  License: BSD-3-Clause
9
9
  License-File: LICENSE
10
10
  Classifier: Development Status :: 3 - Alpha
11
11
  Classifier: License :: OSI Approved :: BSD License
12
12
  Classifier: Programming Language :: Python :: 3
13
- Classifier: Programming Language :: Python :: 3.8
14
13
  Classifier: Programming Language :: Python :: 3.9
15
14
  Classifier: Programming Language :: Python :: 3.10
16
15
  Classifier: Programming Language :: Python :: 3.11
17
16
  Classifier: Programming Language :: Python :: 3.12
18
17
  Classifier: Typing :: Typed
19
- Requires-Python: >=3.8
18
+ Requires-Python: >=3.9
20
19
  Requires-Dist: bioimageio-core>=0.6.0
20
+ Requires-Dist: numpy<2.0.0
21
21
  Requires-Dist: psutil
22
22
  Requires-Dist: pydantic>=2.5
23
23
  Requires-Dist: pytorch-lightning>=2.2.0
24
24
  Requires-Dist: pyyaml
25
- Requires-Dist: scikit-image
25
+ Requires-Dist: scikit-image<=0.23.2
26
26
  Requires-Dist: tifffile
27
27
  Requires-Dist: torch>=2.0.0
28
- Requires-Dist: zarr
28
+ Requires-Dist: torchvision
29
+ Requires-Dist: zarr<3.0.0
29
30
  Provides-Extra: dev
30
31
  Requires-Dist: pre-commit; extra == 'dev'
31
32
  Requires-Dist: pytest; extra == 'dev'