careamics: careamics-0.0.3-py3-none-any.whl → careamics-0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's release page for more details.

Files changed (55):
  1. careamics/careamist.py +25 -17
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/architectures/lvae_model.py +0 -4
  6. careamics/config/configuration_factory.py +480 -177
  7. careamics/config/configuration_model.py +1 -2
  8. careamics/config/data_model.py +1 -15
  9. careamics/config/fcn_algorithm_model.py +14 -9
  10. careamics/config/likelihood_model.py +21 -4
  11. careamics/config/nm_model.py +31 -5
  12. careamics/config/optimizer_models.py +3 -1
  13. careamics/config/support/supported_optimizers.py +1 -1
  14. careamics/config/support/supported_transforms.py +1 -0
  15. careamics/config/training_model.py +35 -6
  16. careamics/config/transformations/__init__.py +4 -1
  17. careamics/config/transformations/transform_union.py +20 -0
  18. careamics/config/vae_algorithm_model.py +2 -36
  19. careamics/dataset/tiling/lvae_tiled_patching.py +90 -8
  20. careamics/lightning/lightning_module.py +10 -8
  21. careamics/lightning/train_data_module.py +2 -2
  22. careamics/losses/loss_factory.py +3 -3
  23. careamics/losses/lvae/losses.py +2 -2
  24. careamics/lvae_training/dataset/__init__.py +15 -0
  25. careamics/lvae_training/dataset/{vae_data_config.py → config.py} +25 -81
  26. careamics/lvae_training/dataset/lc_dataset.py +28 -20
  27. careamics/lvae_training/dataset/{vae_dataset.py → multich_dataset.py} +91 -51
  28. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  29. careamics/lvae_training/dataset/types.py +43 -0
  30. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  31. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  32. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  33. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  34. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  35. careamics/lvae_training/eval_utils.py +109 -64
  36. careamics/lvae_training/get_config.py +1 -1
  37. careamics/lvae_training/train_lvae.py +1 -1
  38. careamics/model_io/bmz_io.py +1 -1
  39. careamics/models/lvae/likelihoods.py +18 -9
  40. careamics/models/lvae/lvae.py +12 -16
  41. careamics/models/lvae/noise_models.py +1 -1
  42. careamics/transforms/compose.py +90 -15
  43. careamics/transforms/n2v_manipulate.py +6 -2
  44. careamics/transforms/normalize.py +14 -3
  45. careamics/transforms/xy_flip.py +16 -6
  46. careamics/transforms/xy_random_rotate90.py +16 -7
  47. careamics/utils/metrics.py +204 -24
  48. careamics/utils/serializers.py +60 -0
  49. {careamics-0.0.3.dist-info → careamics-0.0.4.dist-info}/METADATA +4 -3
  50. {careamics-0.0.3.dist-info → careamics-0.0.4.dist-info}/RECORD +53 -42
  51. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  52. careamics/lvae_training/dataset/data_utils.py +0 -701
  53. careamics/lvae_training/dataset/lc_dataset_config.py +0 -13
  54. {careamics-0.0.3.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  55. {careamics-0.0.3.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -7,6 +7,7 @@ from __future__ import annotations
7
7
  import math
8
8
  from typing import Literal, Union, TYPE_CHECKING, Any, Optional
9
9
 
10
+ import numpy as np
10
11
  import torch
11
12
  from torch import nn
12
13
 
@@ -287,30 +288,37 @@ class NoiseModelLikelihood(LikelihoodModule):
287
288
 
288
289
  def __init__(
289
290
  self,
290
- data_mean: torch.Tensor,
291
- data_std: torch.Tensor,
292
- noiseModel: NoiseModel, # TODO: check the type -> couldn't manage due to circular imports...
291
+ data_mean: Union[np.ndarray, torch.Tensor],
292
+ data_std: Union[np.ndarray, torch.Tensor],
293
+ noiseModel: NoiseModel,
293
294
  ):
294
295
  """Constructor.
295
296
 
296
297
  Parameters
297
298
  ----------
298
- data_mean: torch.Tensor
299
+ data_mean: Union[np.ndarray, torch.Tensor]
299
300
  The mean of the data, used to unnormalize data for noise model evaluation.
300
- data_std: torch.Tensor
301
+ data_std: Union[np.ndarray, torch.Tensor]
301
302
  The standard deviation of the data, used to unnormalize data for noise
302
303
  model evaluation.
303
304
  noiseModel: NoiseModel
304
305
  The noise model instance used to compute the likelihood.
305
306
  """
306
307
  super().__init__()
307
- self.data_mean = data_mean
308
- self.data_std = data_std
308
+ self.data_mean = torch.Tensor(data_mean)
309
+ self.data_std = torch.Tensor(data_std)
309
310
  self.noiseModel = noiseModel
310
311
 
311
- def set_params_to_same_device_as(
312
+ def _set_params_to_same_device_as(
312
313
  self, correct_device_tensor: torch.Tensor
313
- ) -> None: # TODO: needed?
314
+ ) -> None:
315
+ """Set the parameters to the same device as the input tensor.
316
+
317
+ Parameters
318
+ ----------
319
+ correct_device_tensor: torch.Tensor
320
+ The tensor whose device is used to set the parameters.
321
+ """
314
322
  if self.data_mean.device != correct_device_tensor.device:
315
323
  self.data_mean = self.data_mean.to(correct_device_tensor.device)
316
324
  self.data_std = self.data_std.to(correct_device_tensor.device)
@@ -355,6 +363,7 @@ class NoiseModelLikelihood(LikelihoodModule):
355
363
  torch.Tensor
356
364
  The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
357
365
  """
366
+ self._set_params_to_same_device_as(x)
358
367
  predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
359
368
  x_denormalized = x * self.data_std + self.data_mean
360
369
  likelihoods = self.noiseModel.likelihood(
@@ -38,7 +38,6 @@ class LadderVAE(nn.Module):
38
38
  decoder_dropout: float,
39
39
  nonlinearity: str,
40
40
  predict_logvar: bool,
41
- enable_noise_model: bool,
42
41
  analytical_kl: bool,
43
42
  ):
44
43
  """
@@ -62,15 +61,12 @@ class LadderVAE(nn.Module):
62
61
  self.decoder_dropout = decoder_dropout
63
62
  self.nonlin = nonlinearity
64
63
  self.predict_logvar = predict_logvar
65
- self.enable_noise_model = enable_noise_model
66
-
67
64
  self.analytical_kl = analytical_kl
68
65
  # -------------------------------------------------------
69
66
 
70
67
  # -------------------------------------------------------
71
68
  # Model attributes -> Hardcoded
72
69
  self.model_type = ModelType.LadderVae # TODO remove !
73
- self.model_type = ModelType.LadderVae # TODO remove !
74
70
  self.encoder_blocks_per_layer = 1
75
71
  self.decoder_blocks_per_layer = 1
76
72
  self.bottomup_batchnorm = True
@@ -94,13 +90,6 @@ class LadderVAE(nn.Module):
94
90
  self._stochastic_use_naive_exponential = False
95
91
  self._enable_topdown_normalize_factor = True
96
92
 
97
- # Noise model attributes -> Hardcoded
98
- self.noise_model_type = "gmm"
99
- self.denoise_channel = (
100
- "input" # 4 values for denoise_channel {'Ch1', 'Ch2', 'input','all'}
101
- )
102
- self.noise_model_learnable = False
103
-
104
93
  # Attributes that handle LC -> Hardcoded
105
94
  self.enable_multiscale = (
106
95
  self._multiscale_count is not None and self._multiscale_count > 1
@@ -806,11 +795,18 @@ class LadderVAE(nn.Module):
806
795
 
807
796
  # return samples
808
797
 
809
- # def reset_for_different_output_size(self, output_size):
810
- # for i in range(self.n_layers):
811
- # sz = output_size // 2**(1 + i)
812
- # self.bottom_up_layers[i].output_expected_shape = (sz, sz)
813
- # self.top_down_layers[i].latent_shape = (output_size, output_size)
798
+ def reset_for_different_output_size(self, output_size: int) -> None:
799
+ """Reset shape of output and latent tensors for different output size.
800
+
801
+ Used during evaluation to reset expected shapes of tensors when
802
+ input/output shape changes.
803
+ For instance, it is needed when the model was trained on, say, 64x64 sized
804
+ patches, but prediction is done on 128x128 patches.
805
+ """
806
+ for i in range(self.n_layers):
807
+ sz = output_size // 2 ** (1 + i)
808
+ self.bottom_up_layers[i].output_expected_shape = (sz, sz)
809
+ self.top_down_layers[i].latent_shape = (output_size, output_size)
814
810
 
815
811
  def pad_input(self, x):
816
812
  """
@@ -76,7 +76,7 @@ def train_gm_noise_model(
76
76
  # TODO any training params ? Different channels ?
77
77
  noise_model = GaussianMixtureNoiseModel(model_config)
78
78
  # TODO revisit config unpacking
79
- noise_model.train_noise_model(noise_model.signal, noise_model.observation)
79
+ noise_model.train_noise_model(model_config.signal, model_config.observation)
80
80
  return noise_model
81
81
 
82
82
 
@@ -1,13 +1,14 @@
1
1
  """A class chaining transforms together."""
2
2
 
3
- from typing import Dict, List, Optional, Tuple, cast
3
+ from typing import Dict, List, Optional, Tuple, Union, cast
4
4
 
5
- import numpy as np
5
+ from numpy.typing import NDArray
6
6
 
7
- from careamics.config.data_model import TRANSFORMS_UNION
7
+ from careamics.config.transformations import TransformModel
8
8
 
9
9
  from .n2v_manipulate import N2VManipulate
10
10
  from .normalize import Normalize
11
+ from .transform import Transform
11
12
  from .xy_flip import XYFlip
12
13
  from .xy_random_rotate90 import XYRandomRotate90
13
14
 
@@ -36,7 +37,7 @@ class Compose:
36
37
 
37
38
  Parameters
38
39
  ----------
39
- transform_list : List[TRANSFORMS_UNION]
40
+ transform_list : List[TransformModel]
40
41
  A list of dictionaries where each dictionary contains the name of a
41
42
  transform and its parameters.
42
43
 
@@ -46,26 +47,27 @@ class Compose:
46
47
  A callable that applies the transforms to the input data.
47
48
  """
48
49
 
49
- def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
50
+ def __init__(self, transform_list: List[TransformModel]) -> None:
50
51
  """Instantiate a Compose object.
51
52
 
52
53
  Parameters
53
54
  ----------
54
- transform_list : List[TRANSFORMS_UNION]
55
+ transform_list : List[TransformModel]
55
56
  A list of dictionaries where each dictionary contains the name of a
56
57
  transform and its parameters.
57
58
  """
58
59
  # retrieve all available transforms
59
- all_transforms = get_all_transforms()
60
+ # TODO: correctly type hint get_all_transforms function output
61
+ all_transforms: dict[str, type[Transform]] = get_all_transforms()
60
62
 
61
63
  # instantiate all transforms
62
- self.transforms = [
64
+ self.transforms: list[Transform] = [
63
65
  all_transforms[t.name](**t.model_dump()) for t in transform_list
64
66
  ]
65
67
 
66
68
  def _chain_transforms(
67
- self, patch: np.ndarray, target: Optional[np.ndarray]
68
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
69
+ self, patch: NDArray, target: Optional[NDArray]
70
+ ) -> Tuple[Optional[NDArray], ...]:
69
71
  """Chain transforms on the input data.
70
72
 
71
73
  Parameters
@@ -80,16 +82,56 @@ class Compose:
80
82
  Tuple[np.ndarray, Optional[np.ndarray]]
81
83
  The output of the transformations.
82
84
  """
83
- params = (patch, target)
85
+ params: Union[
86
+ tuple[NDArray, Optional[NDArray]],
87
+ tuple[NDArray, NDArray, NDArray], # N2VManiupulate output
88
+ ] = (patch, target)
84
89
 
85
90
  for t in self.transforms:
86
- params = t(*params)
91
+ # N2VManipulate returns tuple of 3 arrays
92
+ # - Other transoforms return tuple of (patch, target, additional_arrays)
93
+ if isinstance(t, N2VManipulate):
94
+ patch, *_ = params
95
+ params = t(patch=patch)
96
+ else:
97
+ *params, _ = t(*params) # ignore additional_arrays dict
87
98
 
88
99
  return params
89
100
 
101
+ def _chain_transforms_additional_arrays(
102
+ self,
103
+ patch: NDArray,
104
+ target: Optional[NDArray],
105
+ **additional_arrays: NDArray,
106
+ ) -> Tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
107
+ """Chain transforms on the input data, with additional arrays.
108
+
109
+ Parameters
110
+ ----------
111
+ patch : np.ndarray
112
+ Input data.
113
+ target : Optional[np.ndarray]
114
+ Target data, by default None.
115
+ **additional_arrays : NDArray
116
+ Additional arrays that will be transformed identically to `patch` and
117
+ `target`.
118
+
119
+ Returns
120
+ -------
121
+ Tuple[np.ndarray, Optional[np.ndarray]]
122
+ The output of the transformations.
123
+ """
124
+ params = {"patch": patch, "target": target, **additional_arrays}
125
+
126
+ for t in self.transforms:
127
+ patch, target, additional_arrays = t(**params)
128
+ params = {"patch": patch, "target": target, **additional_arrays}
129
+
130
+ return patch, target, additional_arrays
131
+
90
132
  def __call__(
91
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
92
- ) -> Tuple[np.ndarray, ...]:
133
+ self, patch: NDArray, target: Optional[NDArray] = None
134
+ ) -> Tuple[NDArray, ...]:
93
135
  """Apply the transforms to the input data.
94
136
 
95
137
  Parameters
@@ -104,4 +146,37 @@ class Compose:
104
146
  Tuple[np.ndarray, ...]
105
147
  The output of the transformations.
106
148
  """
107
- return cast(Tuple[np.ndarray, ...], self._chain_transforms(patch, target))
149
+ # TODO: solve casting Compose.__call__ ouput
150
+ return cast(Tuple[NDArray, ...], self._chain_transforms(patch, target))
151
+
152
+ def transform_with_additional_arrays(
153
+ self,
154
+ patch: NDArray,
155
+ target: Optional[NDArray] = None,
156
+ **additional_arrays: NDArray,
157
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
158
+ """Apply the transforms to the input data, including additional arrays.
159
+
160
+ Parameters
161
+ ----------
162
+ patch : np.ndarray
163
+ The input data.
164
+ target : Optional[np.ndarray], optional
165
+ Target data, by default None.
166
+ **additional_arrays : NDArray
167
+ Additional arrays that will be transformed identically to `patch` and
168
+ `target`.
169
+
170
+ Returns
171
+ -------
172
+ NDArray
173
+ The transformed patch.
174
+ NDArray | None
175
+ The transformed target.
176
+ dict of {str, NDArray}
177
+ Transformed additional arrays. Keys correspond to the keyword argument
178
+ names.
179
+ """
180
+ return self._chain_transforms_additional_arrays(
181
+ patch, target, **additional_arrays
182
+ )
@@ -3,6 +3,7 @@
3
3
  from typing import Any, Literal, Optional, Tuple
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
8
9
  from careamics.transforms.transform import Transform
@@ -98,8 +99,8 @@ class N2VManipulate(Transform):
98
99
  self.rng = np.random.default_rng(seed=seed)
99
100
 
100
101
  def __call__(
101
- self, patch: np.ndarray, *args: Any, **kwargs: Any
102
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
102
+ self, patch: NDArray, *args: Any, **kwargs: Any
103
+ ) -> Tuple[NDArray, NDArray, NDArray]:
103
104
  """Apply the transform to the image.
104
105
 
105
106
  Parameters
@@ -142,5 +143,8 @@ class N2VManipulate(Transform):
142
143
  else:
143
144
  raise ValueError(f"Unknown masking strategy ({self.strategy}).")
144
145
 
146
+ # TODO: Output does not match other transforms, how to resolve?
147
+ # - Don't include in Compose and apply after if algorithm is N2V?
148
+ # - or just don't return patch? but then mask is in the target position
145
149
  # TODO why return patch?
146
150
  return masked, patch, mask
@@ -90,8 +90,11 @@ class Normalize(Transform):
90
90
  self.eps = 1e-6
91
91
 
92
92
  def __call__(
93
- self, patch: np.ndarray, target: Optional[NDArray] = None
94
- ) -> tuple[NDArray, Optional[NDArray]]:
93
+ self,
94
+ patch: np.ndarray,
95
+ target: Optional[NDArray] = None,
96
+ **additional_arrays: NDArray,
97
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
95
98
  """Apply the transform to the source patch and the target (optional).
96
99
 
97
100
  Parameters
@@ -100,6 +103,9 @@ class Normalize(Transform):
100
103
  Patch, 2D or 3D, shape C(Z)YX.
101
104
  target : NDArray, optional
102
105
  Target for the patch, by default None.
106
+ **additional_arrays : NDArray
107
+ Additional arrays that will be transformed identically to `patch` and
108
+ `target`.
103
109
 
104
110
  Returns
105
111
  -------
@@ -111,6 +117,11 @@ class Normalize(Transform):
111
117
  f"Number of means (got a list of size {len(self.image_means)}) and "
112
118
  f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
113
119
  )
120
+ if len(additional_arrays) != 0:
121
+ raise NotImplementedError(
122
+ "Transforming additional arrays is currently not supported for "
123
+ "`Normalize`."
124
+ )
114
125
 
115
126
  # reshape mean and std and apply the normalization to the patch
116
127
  means = _reshape_stats(self.image_means, patch.ndim)
@@ -129,7 +140,7 @@ class Normalize(Transform):
129
140
  else:
130
141
  norm_target = None
131
142
 
132
- return norm_patch, norm_target
143
+ return norm_patch, norm_target, additional_arrays
133
144
 
134
145
  def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
135
146
  """
@@ -1,8 +1,9 @@
1
1
  """XY flip transform."""
2
2
 
3
- from typing import Optional, Tuple
3
+ from typing import Optional
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.transforms.transform import Transform
8
9
 
@@ -78,8 +79,11 @@ class XYFlip(Transform):
78
79
  self.rng = np.random.default_rng(seed=seed)
79
80
 
80
81
  def __call__(
81
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
82
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
82
+ self,
83
+ patch: NDArray,
84
+ target: Optional[NDArray] = None,
85
+ **additional_arrays: NDArray,
86
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
83
87
  """Apply the transform to the source patch and the target (optional).
84
88
 
85
89
  Parameters
@@ -88,6 +92,9 @@ class XYFlip(Transform):
88
92
  Patch, 2D or 3D, shape C(Z)YX.
89
93
  target : Optional[np.ndarray], optional
90
94
  Target for the patch, by default None.
95
+ **additional_arrays : NDArray
96
+ Additional arrays that will be transformed identically to `patch` and
97
+ `target`.
91
98
 
92
99
  Returns
93
100
  -------
@@ -95,17 +102,20 @@ class XYFlip(Transform):
95
102
  Transformed patch and target.
96
103
  """
97
104
  if self.rng.random() > self.p:
98
- return patch, target
105
+ return patch, target, additional_arrays
99
106
 
100
107
  # choose an axis to flip
101
108
  axis = self.rng.choice(self.axis_indices)
102
109
 
103
110
  patch_transformed = self._apply(patch, axis)
104
111
  target_transformed = self._apply(target, axis) if target is not None else None
112
+ additional_transformed = {
113
+ key: self._apply(array, axis) for key, array in additional_arrays.items()
114
+ }
105
115
 
106
- return patch_transformed, target_transformed
116
+ return patch_transformed, target_transformed, additional_transformed
107
117
 
108
- def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
118
+ def _apply(self, patch: NDArray, axis: int) -> NDArray:
109
119
  """Apply the transform to the image.
110
120
 
111
121
  Parameters
@@ -3,6 +3,7 @@
3
3
  from typing import Optional, Tuple
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.transforms.transform import Transform
8
9
 
@@ -49,8 +50,11 @@ class XYRandomRotate90(Transform):
49
50
  self.rng = np.random.default_rng(seed=seed)
50
51
 
51
52
  def __call__(
52
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
53
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
53
+ self,
54
+ patch: NDArray,
55
+ target: Optional[NDArray] = None,
56
+ **additional_arrays: NDArray,
57
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
54
58
  """Apply the transform to the source patch and the target (optional).
55
59
 
56
60
  Parameters
@@ -59,6 +63,9 @@ class XYRandomRotate90(Transform):
59
63
  Patch, 2D or 3D, shape C(Z)YX.
60
64
  target : Optional[np.ndarray], optional
61
65
  Target for the patch, by default None.
66
+ **additional_arrays : NDArray
67
+ Additional arrays that will be transformed identically to `patch` and
68
+ `target`.
62
69
 
63
70
  Returns
64
71
  -------
@@ -66,7 +73,7 @@ class XYRandomRotate90(Transform):
66
73
  Transformed patch and target.
67
74
  """
68
75
  if self.rng.random() > self.p:
69
- return patch, target
76
+ return patch, target, additional_arrays
70
77
 
71
78
  # number of rotations
72
79
  n_rot = self.rng.integers(1, 4)
@@ -76,12 +83,14 @@ class XYRandomRotate90(Transform):
76
83
  target_transformed = (
77
84
  self._apply(target, n_rot, axes) if target is not None else None
78
85
  )
86
+ additional_transformed = {
87
+ key: self._apply(array, n_rot, axes)
88
+ for key, array in additional_arrays.items()
89
+ }
79
90
 
80
- return patch_transformed, target_transformed
91
+ return patch_transformed, target_transformed, additional_transformed
81
92
 
82
- def _apply(
83
- self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
84
- ) -> np.ndarray:
93
+ def _apply(self, patch: NDArray, n_rot: int, axes: Tuple[int, int]) -> NDArray:
85
94
  """Apply the transform to the image.
86
95
 
87
96
  Parameters