careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -4,88 +4,98 @@ Metrics submodule.
4
4
  This module contains various metrics and a metrics tracking class.
5
5
  """
6
6
 
7
- from typing import Union
7
+ from typing import Callable, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  import torch
11
- from skimage.metrics import peak_signal_noise_ratio
11
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
12
+ from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
12
13
 
14
+ # TODO: does this add additional dependency?
13
15
 
14
def psnr(gt: np.ndarray, pred: np.ndarray, data_range: float) -> float:
    """
    Compute the Peak Signal to Noise Ratio between two arrays.

    Thin wrapper around `skimage.metrics.peak_signal_noise_ratio`, see
    https://scikit-image.org/docs/dev/api/skimage.metrics.html.

    NOTE: `data_range` is deliberately mandatory so that it is never inferred
    from the array dtype, which can silently produce unexpected values.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth array.
    pred : np.ndarray
        Predicted array.
    data_range : float
        Pixel value range of the images.

    Returns
    -------
    float
        PSNR value.
    """
    value = peak_signal_noise_ratio(gt, pred, data_range=data_range)
    return value
36
42
 
37
43
 
38
- def _zero_mean(x: np.ndarray) -> np.ndarray:
44
+ def _zero_mean(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
39
45
  """
40
46
  Zero the mean of an array.
41
47
 
42
48
  Parameters
43
49
  ----------
44
- x : NumPy array
50
+ x : numpy.ndarray or torch.Tensor
45
51
  Input array.
46
52
 
47
53
  Returns
48
54
  -------
49
- NumPy array
55
+ numpy.ndarray or torch.Tensor
50
56
  Zero-mean array.
51
57
  """
52
- return x - np.mean(x)
58
+ return x - x.mean()
53
59
 
54
60
 
55
- def _fix_range(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
61
+ def _fix_range(
62
+ gt: Union[np.ndarray, torch.Tensor], x: Union[np.ndarray, torch.Tensor]
63
+ ) -> Union[np.ndarray, torch.Tensor]:
56
64
  """
57
65
  Adjust the range of an array based on a reference ground-truth array.
58
66
 
59
67
  Parameters
60
68
  ----------
61
- gt : np.ndarray
62
- Ground truth image.
63
- x : np.ndarray
69
+ gt : Union[np.ndarray, torch.Tensor]
70
+ Ground truth array.
71
+ x : Union[np.ndarray, torch.Tensor]
64
72
  Input array.
65
73
 
66
74
  Returns
67
75
  -------
68
- np.ndarray
76
+ Union[np.ndarray, torch.Tensor]
69
77
  Range-adjusted array.
70
78
  """
71
- a = np.sum(gt * x) / (np.sum(x * x))
79
+ a = (gt * x).sum() / (x * x).sum()
72
80
  return x * a
73
81
 
74
82
 
75
- def _fix(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
83
+ def _fix(
84
+ gt: Union[np.ndarray, torch.Tensor], x: Union[np.ndarray, torch.Tensor]
85
+ ) -> Union[np.ndarray, torch.Tensor]:
76
86
  """
77
87
  Zero mean a groud truth array and adjust the range of the array.
78
88
 
79
89
  Parameters
80
90
  ----------
81
- gt : np.ndarray
91
+ gt : Union[np.ndarray, torch.Tensor]
82
92
  Ground truth image.
83
- x : np.ndarray
93
+ x : Union[np.ndarray, torch.Tensor]
84
94
  Input array.
85
95
 
86
96
  Returns
87
97
  -------
88
- np.ndarray
98
+ Union[np.ndarray, torch.Tensor]
89
99
  Zero-mean and range-adjusted array.
90
100
  """
91
101
  gt_ = _zero_mean(gt)
@@ -113,3 +123,246 @@ def scale_invariant_psnr(
113
123
  range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt)
114
124
  gt_ = _zero_mean(gt) / np.std(gt)
115
125
  return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter)
126
+
127
+
128
class RunningPSNR:
    """Compute the running PSNR during validation step in training.

    This class allows to compute the PSNR on the entire validation set
    one batch at the time. The PSNR is derived from the average of the
    per-sample MSEs and from the global min/max of all targets seen so far.

    Attributes
    ----------
    N : int
        Number of samples with a non-NaN MSE seen so far during the epoch.
    mse_sum : float or torch.Tensor
        Running sum of the per-sample MSE over the N samples seen so far
        (0 after a reset, a detached tensor after the first update).
    max : float or None
        Running max value of the target images seen so far.
    min : float or None
        Running min value of the target images seen so far.
    """

    def __init__(self):
        """Constructor."""
        self.N = None
        self.mse_sum = None
        self.max = self.min = None
        self.reset()

    def reset(self):
        """Reset the running PSNR computation.

        Usually called at the end of each epoch.
        """
        self.mse_sum = 0
        self.N = 0
        self.max = self.min = None

    def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None:
        """Update the running PSNR statistics given a new batch.

        Parameters
        ----------
        rec : torch.Tensor
            Reconstructed batch, first dimension is the sample dimension.
        tar : torch.Tensor
            Target batch, same shape as `rec`.
        """
        ins_max = torch.max(tar).item()
        ins_min = torch.min(tar).item()
        if self.max is None:
            # first batch since the last reset: min and max are set together
            assert self.min is None
            self.max = ins_max
            self.min = ins_min
        else:
            self.max = max(self.max, ins_max)
            self.min = min(self.min, ins_min)

        # Detach so the running statistics never keep the autograd graph of
        # every processed batch alive across the epoch.
        sq_err = (rec.detach() - tar.detach()) ** 2
        # `reshape` (unlike `view`) also works for non-contiguous tensors.
        per_sample_mse = torch.mean(sq_err.reshape(len(sq_err), -1), dim=1)
        self.mse_sum += torch.nansum(per_sample_mse)
        # Samples with a NaN MSE are excluded from both the sum and the count.
        self.N += int((~torch.isnan(per_sample_mse)).sum().item())

    def get(self) -> Optional[torch.Tensor]:
        """Get the actual PSNR value given the running statistics.

        Returns
        -------
        Optional[torch.Tensor]
            PSNR value, or None if no valid sample has been seen yet.
        """
        if self.N is None or self.N == 0:
            return None
        rmse = torch.sqrt(self.mse_sum / self.N)
        return 20 * torch.log10((self.max - self.min) / rmse)
199
+
200
+
201
def _range_invariant_multiscale_ssim(
    gt_: Union[np.ndarray, torch.Tensor], pred_: Union[np.ndarray, torch.Tensor]
) -> float:
    """Compute range invariant multiscale SSIM for a single channel.

    Both stacks are zero-meaned and the prediction is adjusted against the
    ground truth with `_fix` before the multiscale SSIM is computed. The
    advantage over the commonly used SSIM is that the score is invariant to
    scalar multiplications of the prediction.
    # TODO: Add reference to the paper.

    Parameters
    ----------
    gt_ : Union[np.ndarray, torch.Tensor]
        Single-channel ground truth image stack with shape (N, H, W).
    pred_ : Union[np.ndarray, torch.Tensor]
        Single-channel predicted image stack with shape (N, H, W).

    Returns
    -------
    float
        Range invariant multiscale SSIM value.
    """
    original_shape = gt_.shape
    n_images = original_shape[0]

    # Work on (N, -1) float tensors while normalising and rescaling.
    target_flat = _zero_mean(torch.Tensor(gt_.reshape((n_images, -1))))
    predicted_flat = _zero_mean(torch.Tensor(pred_.reshape((n_images, -1))))
    predicted_flat = _fix(target_flat, predicted_flat)

    target = target_flat.reshape(original_shape)
    predicted = predicted_flat.reshape(original_shape)

    metric = MultiScaleStructuralSimilarityIndexMeasure(
        data_range=target.max() - target.min()
    )
    # A singleton channel axis is inserted: (N, H, W) -> (N, 1, H, W).
    score = metric(torch.Tensor(predicted[:, None]), torch.Tensor(target[:, None]))
    return score.item()
237
+
238
+
239
def multiscale_ssim(
    gt_: Union[np.ndarray, torch.Tensor],
    pred_: Union[np.ndarray, torch.Tensor],
    range_invariant: bool = True,
) -> list[Union[float, None]]:
    """Compute channel-wise multiscale SSIM for each channel.

    Either the standard multiscale SSIM or its range-invariant version is
    used, depending on `range_invariant`.

    NOTE: the channel dimension must be the last one, (N, H, W, C).
    NOTE(review): this differs from the (N, C, H, W) layout documented by the
    other averaging helpers of this module; confirm the intended convention.

    Parameters
    ----------
    gt_ : Union[np.ndarray, torch.Tensor]
        Ground truth image with shape (N, H, W, C).
    pred_ : Union[np.ndarray, torch.Tensor]
        Predicted image with shape (N, H, W, C).
    range_invariant : bool
        Whether to use standard or range invariant multiscale SSIM.

    Returns
    -------
    list[float]
        One multiscale SSIM value per channel, in channel order.
    """
    scores = []
    for channel in range(gt_.shape[-1]):
        gt_channel = gt_[..., channel]
        pred_channel = pred_[..., channel]
        if not range_invariant:
            metric = MultiScaleStructuralSimilarityIndexMeasure(
                data_range=gt_channel.max() - gt_channel.min()
            )
            # (N, H, W) -> (N, 1, H, W) for torchmetrics
            value = metric(
                torch.Tensor(pred_channel[:, None]), torch.Tensor(gt_channel[:, None])
            ).item()
        else:
            value = _range_invariant_multiscale_ssim(gt_=gt_channel, pred_=pred_channel)
        scores.append(value)
    return scores  # type: ignore
282
+
283
+
284
+ def _avg_psnr(target: np.ndarray, prediction: np.ndarray, psnr_fn: Callable) -> float:
285
+ """Compute the average PSNR over a batch of images.
286
+
287
+ Parameters
288
+ ----------
289
+ target : np.ndarray
290
+ Array of ground truth images, shape is (N, C, H, W).
291
+ prediction : np.ndarray
292
+ Array of predicted images, shape is (N, C, H, W).
293
+ psnr_fn : Callable
294
+ PSNR function to use.
295
+
296
+ Returns
297
+ -------
298
+ float
299
+ Average PSNR value over the batch.
300
+ """
301
+ return np.mean(
302
+ [
303
+ psnr_fn(target[i : i + 1], prediction[i : i + 1]).item()
304
+ for i in range(len(prediction))
305
+ ]
306
+ )
307
+
308
+
309
def avg_range_inv_psnr(target: np.ndarray, prediction: np.ndarray) -> float:
    """Average the range-invariant PSNR over a batch of images.

    Each image pair is scored with `scale_invariant_psnr`, and the scores are
    averaged through `_avg_psnr`.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    float
        Mean range-invariant PSNR over the batch.
    """
    average = _avg_psnr(target, prediction, psnr_fn=scale_invariant_psnr)
    return average
325
+
326
+
327
def avg_psnr(target: np.ndarray, prediction: np.ndarray) -> float:
    """Compute the average PSNR over a batch of images.

    `psnr` requires an explicit `data_range`, while `_avg_psnr` calls its
    PSNR function with only (ground_truth, prediction). The data range is
    therefore taken, for each image, as the value range (max - min) of the
    corresponding ground truth image.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    float
        Average PSNR value over the batch.
    """

    def _psnr_with_gt_range(gt: np.ndarray, pred: np.ndarray) -> float:
        # `psnr` has no default data range: derive it from the ground truth.
        return psnr(gt, pred, data_range=gt.max() - gt.min())

    return _avg_psnr(target, prediction, _psnr_with_gt_range)
343
+
344
+
345
def avg_ssim(
    target: Union[np.ndarray, torch.Tensor], prediction: Union[np.ndarray, torch.Tensor]
) -> tuple[float, float]:
    """Compute the average Structural Similarity (SSIM) over a batch of images.

    Each image pair is scored with `skimage.metrics.structural_similarity`,
    using the value range (max - min) of the ground truth image as
    `data_range`. For 4-D inputs (N, C, H, W), the leading axis of every image
    is passed as `channel_axis`, so SSIM is computed per channel and averaged
    by scikit-image instead of treating (C, H, W) as a 3-D volume (which fails
    for fewer than 7 channels with the default window size).

    Parameters
    ----------
    target : Union[np.ndarray, torch.Tensor]
        Array of ground truth images, shape is (N, C, H, W) or (N, H, W).
    prediction : Union[np.ndarray, torch.Tensor]
        Array of predicted images, same shape as `target`.

    Returns
    -------
    tuple[float, float]
        Mean and standard deviation of SSIM values over the batch.
    """
    # scikit-image only operates on NumPy arrays; this is a no-op for arrays
    # and converts CPU tensors.
    target = np.asarray(target)
    prediction = np.asarray(prediction)
    channel_axis = 0 if target.ndim == 4 else None

    ssim = [
        structural_similarity(
            target[i],
            prediction[i],
            data_range=(target[i].max() - target[i].min()),
            channel_axis=channel_axis,
        )
        for i in range(len(target))
    ]
    return np.mean(ssim), np.std(ssim)
@@ -0,0 +1,60 @@
1
+ """A script for serializers in the careamics package."""
2
+
3
+ import ast
4
+ import json
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def _array_to_json(arr: Union[np.ndarray, torch.Tensor]) -> str:
12
+ """Convert an array to a list and then to a JSON string.
13
+
14
+ Parameters
15
+ ----------
16
+ arr : Union[np.ndarray, torch.Tensor]
17
+ Array to be serialized.
18
+
19
+ Returns
20
+ -------
21
+ str
22
+ JSON string representing the array.
23
+ """
24
+ return json.dumps(arr.tolist())
25
+
26
+
27
+ def _to_numpy(lst: Union[str, list]) -> np.ndarray:
28
+ """Deserialize a list or string representing a list into `np.ndarray`.
29
+
30
+ Parameters
31
+ ----------
32
+ lst : list
33
+ List or string representing a list with the array content to be deserialized.
34
+
35
+ Returns
36
+ -------
37
+ np.ndarray
38
+ The deserialized array.
39
+ """
40
+ if isinstance(lst, str):
41
+ lst = ast.literal_eval(lst)
42
+ return np.asarray(lst)
43
+
44
+
45
+ def _to_torch(lst: Union[str, list]) -> torch.Tensor:
46
+ """Deserialize list or string representing a list into `torch.Tensor`.
47
+
48
+ Parameters
49
+ ----------
50
+ lst : Union[str, list]
51
+ List or string representing a list swith the array content to be deserialized.
52
+
53
+ Returns
54
+ -------
55
+ torch.Tensor
56
+ The deserialized tensor.
57
+ """
58
+ if isinstance(lst, str):
59
+ lst = ast.literal_eval(lst)
60
+ return torch.tensor(lst)
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: careamics
3
- Version: 0.0.2
3
+ Version: 0.0.4
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
7
- Author-email: Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
7
+ Author-email: CAREamics team <rse@fht.org>, Ashesh <ashesh.ashesh@fht.org>, Federico Carrara <federico.carrara@fht.org>, Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Vera Galinova <vera.galinova@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
8
8
  License: BSD-3-Clause
9
9
  License-File: LICENSE
10
10
  Classifier: Development Status :: 3 - Alpha
@@ -16,16 +16,17 @@ Classifier: Programming Language :: Python :: 3.11
16
16
  Classifier: Programming Language :: Python :: 3.12
17
17
  Classifier: Typing :: Typed
18
18
  Requires-Python: >=3.9
19
- Requires-Dist: bioimageio-core>=0.6.0
19
+ Requires-Dist: bioimageio-core>=0.6.9
20
20
  Requires-Dist: numpy<2.0.0
21
21
  Requires-Dist: psutil
22
- Requires-Dist: pydantic>=2.5
22
+ Requires-Dist: pydantic<2.9,>=2.5
23
23
  Requires-Dist: pytorch-lightning>=2.2.0
24
24
  Requires-Dist: pyyaml
25
25
  Requires-Dist: scikit-image<=0.23.2
26
26
  Requires-Dist: tifffile
27
27
  Requires-Dist: torch>=2.0.0
28
28
  Requires-Dist: torchvision
29
+ Requires-Dist: typer==0.12.3
29
30
  Requires-Dist: zarr<3.0.0
30
31
  Provides-Extra: dev
31
32
  Requires-Dist: pre-commit; extra == 'dev'