careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; consult the originating registry page for more details.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,13 +1,14 @@
1
1
  """A class chaining transforms together."""
2
2
 
3
- from typing import Dict, List, Optional, Tuple, cast
3
+ from typing import Dict, List, Optional, Tuple, Union, cast
4
4
 
5
- import numpy as np
5
+ from numpy.typing import NDArray
6
6
 
7
- from careamics.config.data_model import TRANSFORMS_UNION
7
+ from careamics.config.transformations import TransformModel
8
8
 
9
9
  from .n2v_manipulate import N2VManipulate
10
10
  from .normalize import Normalize
11
+ from .transform import Transform
11
12
  from .xy_flip import XYFlip
12
13
  from .xy_random_rotate90 import XYRandomRotate90
13
14
 
@@ -36,7 +37,7 @@ class Compose:
36
37
 
37
38
  Parameters
38
39
  ----------
39
- transform_list : List[TRANSFORMS_UNION]
40
+ transform_list : List[TransformModel]
40
41
  A list of dictionaries where each dictionary contains the name of a
41
42
  transform and its parameters.
42
43
 
@@ -46,26 +47,27 @@ class Compose:
46
47
  A callable that applies the transforms to the input data.
47
48
  """
48
49
 
49
- def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
50
+ def __init__(self, transform_list: List[TransformModel]) -> None:
50
51
  """Instantiate a Compose object.
51
52
 
52
53
  Parameters
53
54
  ----------
54
- transform_list : List[TRANSFORMS_UNION]
55
+ transform_list : List[TransformModel]
55
56
  A list of dictionaries where each dictionary contains the name of a
56
57
  transform and its parameters.
57
58
  """
58
59
  # retrieve all available transforms
59
- all_transforms = get_all_transforms()
60
+ # TODO: correctly type hint get_all_transforms function output
61
+ all_transforms: dict[str, type[Transform]] = get_all_transforms()
60
62
 
61
63
  # instantiate all transforms
62
- self.transforms = [
64
+ self.transforms: list[Transform] = [
63
65
  all_transforms[t.name](**t.model_dump()) for t in transform_list
64
66
  ]
65
67
 
66
68
  def _chain_transforms(
67
- self, patch: np.ndarray, target: Optional[np.ndarray]
68
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
69
+ self, patch: NDArray, target: Optional[NDArray]
70
+ ) -> Tuple[Optional[NDArray], ...]:
69
71
  """Chain transforms on the input data.
70
72
 
71
73
  Parameters
@@ -80,16 +82,56 @@ class Compose:
80
82
  Tuple[np.ndarray, Optional[np.ndarray]]
81
83
  The output of the transformations.
82
84
  """
83
- params = (patch, target)
85
+ params: Union[
86
+ tuple[NDArray, Optional[NDArray]],
87
+ tuple[NDArray, NDArray, NDArray], # N2VManiupulate output
88
+ ] = (patch, target)
84
89
 
85
90
  for t in self.transforms:
86
- params = t(*params)
91
+ # N2VManipulate returns tuple of 3 arrays
92
+ # - Other transoforms return tuple of (patch, target, additional_arrays)
93
+ if isinstance(t, N2VManipulate):
94
+ patch, *_ = params
95
+ params = t(patch=patch)
96
+ else:
97
+ *params, _ = t(*params) # ignore additional_arrays dict
87
98
 
88
99
  return params
89
100
 
101
+ def _chain_transforms_additional_arrays(
102
+ self,
103
+ patch: NDArray,
104
+ target: Optional[NDArray],
105
+ **additional_arrays: NDArray,
106
+ ) -> Tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
107
+ """Chain transforms on the input data, with additional arrays.
108
+
109
+ Parameters
110
+ ----------
111
+ patch : np.ndarray
112
+ Input data.
113
+ target : Optional[np.ndarray]
114
+ Target data, by default None.
115
+ **additional_arrays : NDArray
116
+ Additional arrays that will be transformed identically to `patch` and
117
+ `target`.
118
+
119
+ Returns
120
+ -------
121
+ Tuple[np.ndarray, Optional[np.ndarray]]
122
+ The output of the transformations.
123
+ """
124
+ params = {"patch": patch, "target": target, **additional_arrays}
125
+
126
+ for t in self.transforms:
127
+ patch, target, additional_arrays = t(**params)
128
+ params = {"patch": patch, "target": target, **additional_arrays}
129
+
130
+ return patch, target, additional_arrays
131
+
90
132
  def __call__(
91
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
92
- ) -> Tuple[np.ndarray, ...]:
133
+ self, patch: NDArray, target: Optional[NDArray] = None
134
+ ) -> Tuple[NDArray, ...]:
93
135
  """Apply the transforms to the input data.
94
136
 
95
137
  Parameters
@@ -104,4 +146,37 @@ class Compose:
104
146
  Tuple[np.ndarray, ...]
105
147
  The output of the transformations.
106
148
  """
107
- return cast(Tuple[np.ndarray, ...], self._chain_transforms(patch, target))
149
+ # TODO: solve casting Compose.__call__ ouput
150
+ return cast(Tuple[NDArray, ...], self._chain_transforms(patch, target))
151
+
152
+ def transform_with_additional_arrays(
153
+ self,
154
+ patch: NDArray,
155
+ target: Optional[NDArray] = None,
156
+ **additional_arrays: NDArray,
157
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
158
+ """Apply the transforms to the input data, including additional arrays.
159
+
160
+ Parameters
161
+ ----------
162
+ patch : np.ndarray
163
+ The input data.
164
+ target : Optional[np.ndarray], optional
165
+ Target data, by default None.
166
+ **additional_arrays : NDArray
167
+ Additional arrays that will be transformed identically to `patch` and
168
+ `target`.
169
+
170
+ Returns
171
+ -------
172
+ NDArray
173
+ The transformed patch.
174
+ NDArray | None
175
+ The transformed target.
176
+ dict of {str, NDArray}
177
+ Transformed additional arrays. Keys correspond to the keyword argument
178
+ names.
179
+ """
180
+ return self._chain_transforms_additional_arrays(
181
+ patch, target, **additional_arrays
182
+ )
@@ -3,6 +3,7 @@
3
3
  from typing import Any, Literal, Optional, Tuple
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
8
9
  from careamics.transforms.transform import Transform
@@ -98,8 +99,8 @@ class N2VManipulate(Transform):
98
99
  self.rng = np.random.default_rng(seed=seed)
99
100
 
100
101
  def __call__(
101
- self, patch: np.ndarray, *args: Any, **kwargs: Any
102
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
102
+ self, patch: NDArray, *args: Any, **kwargs: Any
103
+ ) -> Tuple[NDArray, NDArray, NDArray]:
103
104
  """Apply the transform to the image.
104
105
 
105
106
  Parameters
@@ -142,5 +143,8 @@ class N2VManipulate(Transform):
142
143
  else:
143
144
  raise ValueError(f"Unknown masking strategy ({self.strategy}).")
144
145
 
146
+ # TODO: Output does not match other transforms, how to resolve?
147
+ # - Don't include in Compose and apply after if algorithm is N2V?
148
+ # - or just don't return patch? but then mask is in the target position
145
149
  # TODO why return patch?
146
150
  return masked, patch, mask
@@ -90,8 +90,11 @@ class Normalize(Transform):
90
90
  self.eps = 1e-6
91
91
 
92
92
  def __call__(
93
- self, patch: np.ndarray, target: Optional[NDArray] = None
94
- ) -> tuple[NDArray, Optional[NDArray]]:
93
+ self,
94
+ patch: np.ndarray,
95
+ target: Optional[NDArray] = None,
96
+ **additional_arrays: NDArray,
97
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
95
98
  """Apply the transform to the source patch and the target (optional).
96
99
 
97
100
  Parameters
@@ -100,6 +103,9 @@ class Normalize(Transform):
100
103
  Patch, 2D or 3D, shape C(Z)YX.
101
104
  target : NDArray, optional
102
105
  Target for the patch, by default None.
106
+ **additional_arrays : NDArray
107
+ Additional arrays that will be transformed identically to `patch` and
108
+ `target`.
103
109
 
104
110
  Returns
105
111
  -------
@@ -111,6 +117,11 @@ class Normalize(Transform):
111
117
  f"Number of means (got a list of size {len(self.image_means)}) and "
112
118
  f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
113
119
  )
120
+ if len(additional_arrays) != 0:
121
+ raise NotImplementedError(
122
+ "Transforming additional arrays is currently not supported for "
123
+ "`Normalize`."
124
+ )
114
125
 
115
126
  # reshape mean and std and apply the normalization to the patch
116
127
  means = _reshape_stats(self.image_means, patch.ndim)
@@ -129,7 +140,7 @@ class Normalize(Transform):
129
140
  else:
130
141
  norm_target = None
131
142
 
132
- return norm_patch, norm_target
143
+ return norm_patch, norm_target, additional_arrays
133
144
 
134
145
  def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
135
146
  """
@@ -161,7 +161,7 @@ def _get_stratified_coords(
161
161
  coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T
162
162
 
163
163
  grid_random_increment = rng.integers(
164
- _odd_jitter_func(float(max(steps)), rng)
164
+ _odd_jitter_func(float(max(steps)), rng) # type: ignore
165
165
  * np.ones_like(coordinate_grid).astype(np.int32)
166
166
  - 1,
167
167
  size=coordinate_grid.shape,
@@ -1,8 +1,9 @@
1
1
  """XY flip transform."""
2
2
 
3
- from typing import Optional, Tuple
3
+ from typing import Optional
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.transforms.transform import Transform
8
9
 
@@ -78,8 +79,11 @@ class XYFlip(Transform):
78
79
  self.rng = np.random.default_rng(seed=seed)
79
80
 
80
81
  def __call__(
81
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
82
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
82
+ self,
83
+ patch: NDArray,
84
+ target: Optional[NDArray] = None,
85
+ **additional_arrays: NDArray,
86
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
83
87
  """Apply the transform to the source patch and the target (optional).
84
88
 
85
89
  Parameters
@@ -88,6 +92,9 @@ class XYFlip(Transform):
88
92
  Patch, 2D or 3D, shape C(Z)YX.
89
93
  target : Optional[np.ndarray], optional
90
94
  Target for the patch, by default None.
95
+ **additional_arrays : NDArray
96
+ Additional arrays that will be transformed identically to `patch` and
97
+ `target`.
91
98
 
92
99
  Returns
93
100
  -------
@@ -95,17 +102,20 @@ class XYFlip(Transform):
95
102
  Transformed patch and target.
96
103
  """
97
104
  if self.rng.random() > self.p:
98
- return patch, target
105
+ return patch, target, additional_arrays
99
106
 
100
107
  # choose an axis to flip
101
108
  axis = self.rng.choice(self.axis_indices)
102
109
 
103
110
  patch_transformed = self._apply(patch, axis)
104
111
  target_transformed = self._apply(target, axis) if target is not None else None
112
+ additional_transformed = {
113
+ key: self._apply(array, axis) for key, array in additional_arrays.items()
114
+ }
105
115
 
106
- return patch_transformed, target_transformed
116
+ return patch_transformed, target_transformed, additional_transformed
107
117
 
108
- def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
118
+ def _apply(self, patch: NDArray, axis: int) -> NDArray:
109
119
  """Apply the transform to the image.
110
120
 
111
121
  Parameters
@@ -3,6 +3,7 @@
3
3
  from typing import Optional, Tuple
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.transforms.transform import Transform
8
9
 
@@ -49,8 +50,11 @@ class XYRandomRotate90(Transform):
49
50
  self.rng = np.random.default_rng(seed=seed)
50
51
 
51
52
  def __call__(
52
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
53
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
53
+ self,
54
+ patch: NDArray,
55
+ target: Optional[NDArray] = None,
56
+ **additional_arrays: NDArray,
57
+ ) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
54
58
  """Apply the transform to the source patch and the target (optional).
55
59
 
56
60
  Parameters
@@ -59,6 +63,9 @@ class XYRandomRotate90(Transform):
59
63
  Patch, 2D or 3D, shape C(Z)YX.
60
64
  target : Optional[np.ndarray], optional
61
65
  Target for the patch, by default None.
66
+ **additional_arrays : NDArray
67
+ Additional arrays that will be transformed identically to `patch` and
68
+ `target`.
62
69
 
63
70
  Returns
64
71
  -------
@@ -66,7 +73,7 @@ class XYRandomRotate90(Transform):
66
73
  Transformed patch and target.
67
74
  """
68
75
  if self.rng.random() > self.p:
69
- return patch, target
76
+ return patch, target, additional_arrays
70
77
 
71
78
  # number of rotations
72
79
  n_rot = self.rng.integers(1, 4)
@@ -76,12 +83,14 @@ class XYRandomRotate90(Transform):
76
83
  target_transformed = (
77
84
  self._apply(target, n_rot, axes) if target is not None else None
78
85
  )
86
+ additional_transformed = {
87
+ key: self._apply(array, n_rot, axes)
88
+ for key, array in additional_arrays.items()
89
+ }
79
90
 
80
- return patch_transformed, target_transformed
91
+ return patch_transformed, target_transformed, additional_transformed
81
92
 
82
- def _apply(
83
- self, patch: np.ndarray, n_rot: int, axes: Tuple[int, int]
84
- ) -> np.ndarray:
93
+ def _apply(self, patch: NDArray, n_rot: int, axes: Tuple[int, int]) -> NDArray:
85
94
  """Apply the transform to the image.
86
95
 
87
96
  Parameters