careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's registry page for more details.

Files changed (59):
  1. careamics/careamist.py +6 -12
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +9 -8
  12. careamics/config/configuration_factories.py +843 -29
  13. careamics/config/data/data_model.py +1 -2
  14. careamics/config/data/ng_data_model.py +1 -2
  15. careamics/config/inference_model.py +1 -2
  16. careamics/config/likelihood_model.py +2 -2
  17. careamics/config/loss_model.py +6 -2
  18. careamics/config/nm_model.py +26 -1
  19. careamics/config/optimizer_models.py +1 -2
  20. careamics/config/support/supported_algorithms.py +5 -3
  21. careamics/config/support/supported_losses.py +5 -2
  22. careamics/config/training_model.py +6 -36
  23. careamics/config/transformations/normalize_model.py +1 -2
  24. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  25. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  26. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  27. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  28. careamics/file_io/read/__init__.py +0 -1
  29. careamics/lightning/__init__.py +16 -2
  30. careamics/lightning/callbacks/__init__.py +2 -0
  31. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  32. careamics/lightning/lightning_module.py +161 -61
  33. careamics/lightning/microsplit_data_module.py +631 -0
  34. careamics/lightning/predict_data_module.py +8 -1
  35. careamics/lightning/train_data_module.py +19 -8
  36. careamics/losses/__init__.py +7 -1
  37. careamics/losses/loss_factory.py +9 -1
  38. careamics/losses/lvae/losses.py +85 -0
  39. careamics/lvae_training/dataset/__init__.py +8 -8
  40. careamics/lvae_training/dataset/config.py +56 -44
  41. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  42. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  43. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  44. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  45. careamics/model_io/bmz_io.py +9 -5
  46. careamics/models/lvae/likelihoods.py +30 -14
  47. careamics/models/lvae/lvae.py +2 -2
  48. careamics/models/lvae/noise_models.py +20 -14
  49. careamics/prediction_utils/__init__.py +8 -2
  50. careamics/prediction_utils/prediction_outputs.py +48 -3
  51. careamics/prediction_utils/stitch_prediction.py +71 -0
  52. careamics/transforms/xy_random_rotate90.py +1 -1
  53. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
  54. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
  55. careamics/dataset/zarr_dataset.py +0 -151
  56. careamics/file_io/read/zarr.py +0 -60
  57. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  58. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  59. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
@@ -54,12 +54,8 @@ def likelihood_factory(
54
54
  )
55
55
  elif isinstance(config, NMLikelihoodConfig):
56
56
  return NoiseModelLikelihood(
57
- data_mean=config.data_mean,
58
- data_std=config.data_std,
59
57
  noise_model=noise_model,
60
58
  )
61
- else:
62
- raise ValueError(f"Invalid likelihood model type: {config.model_type}")
63
59
 
64
60
 
65
61
  # TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
@@ -290,27 +286,40 @@ class NoiseModelLikelihood(LikelihoodModule):
290
286
 
291
287
  def __init__(
292
288
  self,
293
- data_mean: Union[np.ndarray, torch.Tensor],
294
- data_std: Union[np.ndarray, torch.Tensor],
295
289
  noise_model: NoiseModel,
296
290
  ):
297
291
  """Constructor.
298
292
 
299
293
  Parameters
300
294
  ----------
301
- data_mean: Union[np.ndarray, torch.Tensor]
302
- The mean of the data, used to unnormalize data for noise model evaluation.
303
- data_std: Union[np.ndarray, torch.Tensor]
304
- The standard deviation of the data, used to unnormalize data for noise
305
- model evaluation.
306
295
  noiseModel: NoiseModel
307
296
  The noise model instance used to compute the likelihood.
308
297
  """
309
298
  super().__init__()
310
- self.data_mean = torch.Tensor(data_mean)
311
- self.data_std = torch.Tensor(data_std)
299
+ self.data_mean = None
300
+ self.data_std = None
312
301
  self.noiseModel = noise_model
313
302
 
303
+ def set_data_stats(
304
+ self,
305
+ data_mean: Union[np.ndarray, torch.Tensor],
306
+ data_std: Union[np.ndarray, torch.Tensor],
307
+ ) -> None:
308
+ """Set the data mean and std for denormalization.
309
+ # TODO check this !!
310
+ Parameters
311
+ ----------
312
+ data_mean : Union[np.ndarray, torch.Tensor]
313
+ Mean values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
314
+ data_std : Union[np.ndarray, torch.Tensor]
315
+ Standard deviation values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
316
+ """
317
+ # Convert to tensor if needed
318
+ self.data_mean = torch.as_tensor(data_mean, dtype=torch.float32)
319
+ self.data_std = torch.as_tensor(data_std, dtype=torch.float32)
320
+
321
+ # TODO add extra dim for 3D ?
322
+
314
323
  def _set_params_to_same_device_as(
315
324
  self, correct_device_tensor: torch.Tensor
316
325
  ) -> None:
@@ -321,7 +330,10 @@ class NoiseModelLikelihood(LikelihoodModule):
321
330
  correct_device_tensor: torch.Tensor
322
331
  The tensor whose device is used to set the parameters.
323
332
  """
324
- if self.data_mean.device != correct_device_tensor.device:
333
+ if (
334
+ self.data_mean is not None
335
+ and self.data_mean.device != correct_device_tensor.device
336
+ ):
325
337
  self.data_mean = self.data_mean.to(correct_device_tensor.device)
326
338
  self.data_std = self.data_std.to(correct_device_tensor.device)
327
339
  if correct_device_tensor.device != self.noiseModel.device:
@@ -367,6 +379,10 @@ class NoiseModelLikelihood(LikelihoodModule):
367
379
  torch.Tensor
368
380
  The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
369
381
  """
382
+ if self.data_mean is None or self.data_std is None:
383
+ raise RuntimeError(
384
+ "NoiseModelLikelihood: data_mean and data_std must be set before calling log_likelihood."
385
+ )
370
386
  self._set_params_to_same_device_as(x)
371
387
  predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
372
388
  x_denormalized = x * self.data_std + self.data_mean
@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
6
6
  """
7
7
 
8
8
  from collections.abc import Iterable
9
- from typing import Optional, Union
9
+ from typing import Union
10
10
 
11
11
  import numpy as np
12
12
  import torch
@@ -835,7 +835,7 @@ class LadderVAE(nn.Module):
835
835
  top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
836
836
  return top_layer_shape
837
837
 
838
- def reset_for_inference(self, tile_size: Optional[tuple[int, int]] = None):
838
+ def reset_for_inference(self, tile_size: tuple[int, int] | None = None):
839
839
  """Should be called if we want to predict for a different input/output size."""
840
840
  self.mode_pred = True
841
841
  if tile_size is None:
@@ -3,10 +3,10 @@ from __future__ import annotations
3
3
  import os
4
4
  from typing import TYPE_CHECKING, Optional
5
5
 
6
- from numpy.typing import NDArray
7
6
  import numpy as np
8
7
  import torch
9
8
  import torch.nn as nn
9
+ from numpy.typing import NDArray
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
@@ -355,16 +355,16 @@ class GaussianMixtureNoiseModel(nn.Module):
355
355
 
356
356
  Parameters
357
357
  ----------
358
- x: Tensor
359
- Observations
360
- mean: Tensor
361
- Mean
362
- std: Tensor
363
- Standard-deviation
358
+ x: torch.Tensor
359
+ The ground-truth tensor. Shape is (batch, 1, dim1, dim2).
360
+ mean: torch.Tensor
361
+ The inferred mean of distribution. Shape is (batch, 1, dim1, dim2).
362
+ std: torch.Tensor
363
+ The inferred standard deviation of distribution. Shape is (batch, 1, dim1, dim2).
364
364
 
365
365
  Returns
366
366
  -------
367
- tmp: Tensor
367
+ tmp: torch.Tensor
368
368
  Normal probability density of `x` given `mean` and `std`
369
369
  """
370
370
  tmp = -((x - mean) ** 2)
@@ -382,9 +382,9 @@ class GaussianMixtureNoiseModel(nn.Module):
382
382
  Parameters
383
383
  ----------
384
384
  observations : Tensor
385
- Noisy observations
385
+ Noisy observations. Shape is (batch, 1, dim1, dim2).
386
386
  signals : Tensor
387
- Underlying signals
387
+ Underlying signals. Shape is (batch, 1, dim1, dim2).
388
388
 
389
389
  Returns
390
390
  -------
@@ -392,15 +392,21 @@ class GaussianMixtureNoiseModel(nn.Module):
392
392
  Likelihood of observations given the signals and the GMM noise model
393
393
  """
394
394
  gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
395
- p = 0
395
+ p = torch.zeros_like(observations)
396
396
  for gaussian in range(self.n_gaussian):
397
+ # Ensure all tensors have compatible shapes
398
+ mean = gaussian_parameters[gaussian]
399
+ std = gaussian_parameters[self.n_gaussian + gaussian]
400
+ weight = gaussian_parameters[2 * self.n_gaussian + gaussian]
401
+
402
+ # Compute normal density
397
403
  p += (
398
404
  self.normal_density(
399
405
  observations,
400
- gaussian_parameters[gaussian],
401
- gaussian_parameters[self.n_gaussian + gaussian],
406
+ mean,
407
+ std,
402
408
  )
403
- * gaussian_parameters[2 * self.n_gaussian + gaussian]
409
+ * weight
404
410
  )
405
411
  return p + self.tolerance
406
412
 
@@ -2,9 +2,15 @@
2
2
 
3
3
  __all__ = [
4
4
  "convert_outputs",
5
+ "convert_outputs_microsplit",
5
6
  "stitch_prediction",
6
7
  "stitch_prediction_single",
8
+ "stitch_prediction_vae",
7
9
  ]
8
10
 
9
- from .prediction_outputs import convert_outputs
10
- from .stitch_prediction import stitch_prediction, stitch_prediction_single
11
+ from .prediction_outputs import convert_outputs, convert_outputs_microsplit
12
+ from .stitch_prediction import (
13
+ stitch_prediction,
14
+ stitch_prediction_single,
15
+ stitch_prediction_vae,
16
+ )
@@ -6,7 +6,7 @@ import numpy as np
6
6
  from numpy.typing import NDArray
7
7
 
8
8
  from ..config.tile_information import TileInformation
9
- from .stitch_prediction import stitch_prediction
9
+ from .stitch_prediction import stitch_prediction, stitch_prediction_vae
10
10
 
11
11
 
12
12
  def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
@@ -41,6 +41,48 @@ def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
41
41
  return predictions_output
42
42
 
43
43
 
44
+ def convert_outputs_microsplit(
45
+ predictions: list[tuple[NDArray, NDArray]], dataset
46
+ ) -> tuple[NDArray, NDArray]:
47
+ """
48
+ Convert microsplit Lightning trainer outputs using eval_utils stitching functions.
49
+
50
+ This function processes microsplit predictions that return (tile_prediction, tile_std) tuples
51
+ and stitches them back together using the same logic as get_single_file_mmse.
52
+
53
+ Parameters
54
+ ----------
55
+ predictions : list of tuple[NDArray, NDArray]
56
+ Predictions from Lightning trainer for microsplit. Each element is a tuple of
57
+ (tile_prediction, tile_std) where both are numpy arrays from predict_step.
58
+ dataset : Dataset
59
+ The dataset object used for prediction, needed for stitching function selection
60
+ and stitching process.
61
+
62
+ Returns
63
+ -------
64
+ tuple[NDArray, NDArray]
65
+ A tuple of (stitched_predictions, stitched_stds) representing the full
66
+ stitched predictions and standard deviations.
67
+ """
68
+ if len(predictions) == 0:
69
+ raise ValueError("No predictions provided")
70
+
71
+ # Separate predictions and stds from the list of tuples
72
+ tile_predictions = [pred for pred, _ in predictions]
73
+ tile_stds = [std for _, std in predictions]
74
+
75
+ # Concatenate all tiles exactly like get_single_file_mmse
76
+ tiles_arr = np.concatenate(tile_predictions, axis=0)
77
+ tile_stds_arr = np.concatenate(tile_stds, axis=0)
78
+
79
+ # Apply stitching using stitch_predictions_new
80
+ stitched_predictions = stitch_prediction_vae(tiles_arr, dataset)
81
+ stitched_stds = stitch_prediction_vae(tile_stds_arr, dataset)
82
+
83
+ return stitched_predictions, stitched_stds
84
+
85
+
44
86
  # for mypy
45
87
  @overload
46
88
  def combine_batches( # numpydoc ignore=GL08
@@ -68,6 +110,8 @@ def combine_batches(
68
110
  """
69
111
  If predictions are in batches, they will be combined.
70
112
 
113
+ # TODO improve description!
114
+
71
115
  Parameters
72
116
  ----------
73
117
  predictions : list
@@ -107,11 +151,12 @@ def _combine_tiled_batches(
107
151
  """
108
152
  # turn list of lists into single list
109
153
  tile_infos = [
110
- tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
154
+ tile_info for *_, tile_info_list in predictions for tile_info in tile_info_list
111
155
  ]
112
156
  prediction_tiles: list[NDArray] = _combine_array_batches(
113
- [preds for preds, _ in predictions]
157
+ [preds for preds, *_ in predictions]
114
158
  )
159
+
115
160
  return prediction_tiles, tile_infos
116
161
 
117
162
 
@@ -9,6 +9,77 @@ from numpy.typing import NDArray
9
9
  from careamics.config.tile_information import TileInformation
10
10
 
11
11
 
12
+ class TilingMode:
13
+ """Enum for the tiling mode."""
14
+
15
+ TrimBoundary = 0
16
+ PadBoundary = 1
17
+ ShiftBoundary = 2
18
+
19
+
20
+ def stitch_prediction_vae(predictions, dset):
21
+ """
22
+ Stitch predictions back together using dataset's index manager.
23
+
24
+ Args:
25
+ predictions: Array of predictions with shape (n_tiles, channels, height, width)
26
+ dset: Dataset object with idx_manager containing tiling information
27
+ """
28
+ mng = dset.idx_manager
29
+
30
+ # if there are more channels, use all of them.
31
+ shape = list(dset.get_data_shape())
32
+ shape[-1] = max(shape[-1], predictions.shape[1])
33
+
34
+ output = np.zeros(shape, dtype=predictions.dtype)
35
+ # frame_shape = dset.get_data_shape()[:-1]
36
+ for dset_idx in range(predictions.shape[0]):
37
+ # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2], predictions.shape[-1])
38
+ # grid start, grid end
39
+ gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
40
+ ge = gs + mng.grid_shape
41
+
42
+ # patch start, patch end
43
+ ps = gs - mng.patch_offset()
44
+ pe = ps + mng.patch_shape
45
+
46
+ # valid grid start, valid grid end
47
+ vgs = np.array([max(0, x) for x in gs], dtype=int)
48
+ vge = np.array(
49
+ [min(x, y) for x, y in zip(ge, mng.data_shape, strict=False)], dtype=int
50
+ )
51
+
52
+ if mng.tiling_mode == TilingMode.ShiftBoundary:
53
+ for dim in range(len(vgs)):
54
+ if ps[dim] == 0:
55
+ vgs[dim] = 0
56
+ if pe[dim] == mng.data_shape[dim]:
57
+ vge[dim] = mng.data_shape[dim]
58
+
59
+ # relative start, relative end. This will be used on pred_tiled
60
+ rs = vgs - ps
61
+ re = rs + (vge - vgs)
62
+
63
+ for ch_idx in range(predictions.shape[1]):
64
+ if len(output.shape) == 4:
65
+ # channel dimension is the last one.
66
+ output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
67
+ predictions[dset_idx][ch_idx, rs[1] : re[1], rs[2] : re[2]]
68
+ )
69
+ elif len(output.shape) == 5:
70
+ # channel dimension is the last one.
71
+ assert vge[0] - vgs[0] == 1, "Only one frame is supported"
72
+ output[
73
+ vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
74
+ ] = predictions[dset_idx][
75
+ ch_idx, rs[1] : re[1], rs[2] : re[2], rs[3] : re[3]
76
+ ]
77
+ else:
78
+ raise ValueError(f"Unsupported shape {output.shape}")
79
+
80
+ return output
81
+
82
+
12
83
  # TODO: why not allow input and output of torch.tensor ?
13
84
  def stitch_prediction(
14
85
  tiles: list[np.ndarray],
@@ -74,7 +74,7 @@ class XYRandomRotate90(Transform):
74
74
  return patch, target, additional_arrays
75
75
 
76
76
  # number of rotations
77
- n_rot = self.rng.integers(1, 4)
77
+ n_rot = int(self.rng.integers(1, 4))
78
78
 
79
79
  axes = (-2, -1)
80
80
  patch_transformed = self._apply(patch, n_rot, axes)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: careamics
3
- Version: 0.0.15
3
+ Version: 0.0.16
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
@@ -10,28 +10,31 @@ License-File: LICENSE
10
10
  Classifier: Development Status :: 3 - Alpha
11
11
  Classifier: License :: OSI Approved :: BSD License
12
12
  Classifier: Programming Language :: Python :: 3
13
- Classifier: Programming Language :: Python :: 3.10
14
13
  Classifier: Programming Language :: Python :: 3.11
15
14
  Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
16
  Classifier: Typing :: Typed
17
- Requires-Python: >=3.10
18
- Requires-Dist: bioimageio-core==0.9.0
19
- Requires-Dist: matplotlib<=3.10.3
20
- Requires-Dist: numpy<2.0.0
21
- Requires-Dist: pillow<=11.2.1
22
- Requires-Dist: psutil<=7.0.0
17
+ Requires-Python: >=3.11
18
+ Requires-Dist: bioimageio-core>=0.9.0
19
+ Requires-Dist: matplotlib<=3.10.6
20
+ Requires-Dist: numpy>=1.21
21
+ Requires-Dist: numpy>=2.1.0; python_version >= '3.13'
22
+ Requires-Dist: pillow<=11.3.0
23
+ Requires-Dist: psutil<=7.1.0
23
24
  Requires-Dist: pydantic<=2.12,>=2.11
24
- Requires-Dist: pytorch-lightning<=2.5.2,>=2.2
25
- Requires-Dist: pyyaml!=6.0.0,<=6.0.2
25
+ Requires-Dist: pytorch-lightning<=2.5.5,>=2.2
26
+ Requires-Dist: pyyaml!=6.0.0,<=6.0.3
26
27
  Requires-Dist: scikit-image<=0.25.2
27
- Requires-Dist: tifffile<=2025.5.10
28
- Requires-Dist: torch<=2.7.1,>=2.0
29
- Requires-Dist: torchvision<=0.22.1
30
- Requires-Dist: typer<=0.16.0,>=0.12.3
31
- Requires-Dist: zarr<3.0.0
28
+ Requires-Dist: tifffile<=2025.9.30
29
+ Requires-Dist: torch<=2.8.0,>=2.0
30
+ Requires-Dist: torchvision<=0.23.0
31
+ Requires-Dist: typer<=0.19.2,>=0.12.3
32
+ Requires-Dist: validators<=0.35.0
33
+ Requires-Dist: zarr<4.0.0,>=3.0.0
32
34
  Provides-Extra: czi
33
35
  Requires-Dist: pylibczirw<6.0.0,>=4.1.2; extra == 'czi'
34
36
  Provides-Extra: dev
37
+ Requires-Dist: ml-dtypes>=0.5.0; extra == 'dev'
35
38
  Requires-Dist: onnx; extra == 'dev'
36
39
  Requires-Dist: pre-commit; extra == 'dev'
37
40
  Requires-Dist: pytest; extra == 'dev'