careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155):
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,107 @@
1
+ """A class chaining transforms together."""
2
+
3
+ from typing import Dict, List, Optional, Tuple, cast
4
+
5
+ import numpy as np
6
+
7
+ from careamics.config.data_model import TRANSFORMS_UNION
8
+
9
+ from .n2v_manipulate import N2VManipulate
10
+ from .normalize import Normalize
11
+ from .xy_flip import XYFlip
12
+ from .xy_random_rotate90 import XYRandomRotate90
13
+
14
# Registry of the transforms CAREamics can instantiate, keyed by the transform name
# used in the configuration models.
ALL_TRANSFORMS = dict(
    Normalize=Normalize,
    N2VManipulate=N2VManipulate,
    XYFlip=XYFlip,
    XYRandomRotate90=XYRandomRotate90,
)
20
+
21
+
22
def get_all_transforms() -> Dict[str, type]:
    """Return all the transforms accepted by CAREamics.

    A shallow copy of the module-level `ALL_TRANSFORMS` registry is returned, so
    that callers adding, removing or replacing entries in the returned dictionary
    cannot alter the registry shared by the rest of the package.

    Returns
    -------
    dict
        A dictionary with all the transforms accepted by CAREamics, where the keys are
        the transform names and the values are the transform classes.
    """
    return dict(ALL_TRANSFORMS)
32
+
33
+
34
class Compose:
    """A class chaining transforms together.

    Each transform is instantiated from its configuration model, and the
    transforms are then applied sequentially: the tuple returned by one transform
    is unpacked as the positional arguments of the next one.

    Parameters
    ----------
    transform_list : List[TRANSFORMS_UNION]
        A list of transform configuration models. Each model provides the name of
        the transform (`name` attribute) and its parameters (`model_dump()`).

    Attributes
    ----------
    transforms : list
        The instantiated transforms, in the order in which they are applied.
    """

    def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
        """Instantiate a Compose object.

        Parameters
        ----------
        transform_list : List[TRANSFORMS_UNION]
            A list of transform configuration models. Each model provides the name
            of the transform (`name` attribute) and its parameters
            (`model_dump()`).

        Raises
        ------
        KeyError
            If a transform name is not among those returned by
            `get_all_transforms`.
        """
        # retrieve all available transforms (name -> class)
        all_transforms = get_all_transforms()

        # instantiate all transforms, keeping the order of the configuration
        self.transforms = [
            all_transforms[t.name](**t.model_dump()) for t in transform_list
        ]

    def _chain_transforms(
        self, patch: np.ndarray, target: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Chain transforms on the input data.

        The output tuple of each transform is unpacked as the positional arguments
        of the next one. The final tuple may therefore hold more than two arrays if
        a transform returns more (e.g. `N2VManipulate` returns three).

        Parameters
        ----------
        patch : np.ndarray
            Input data.
        target : Optional[np.ndarray]
            Target data, or None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            The output of the last transformation (the inputs unchanged if there
            are no transforms).
        """
        params = (patch, target)

        for t in self.transforms:
            params = t(*params)

        return params

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, ...]:
        """Apply the transforms to the input data.

        Parameters
        ----------
        patch : np.ndarray
            The input data.
        target : Optional[np.ndarray], optional
            Target data, by default None.

        Returns
        -------
        Tuple[np.ndarray, ...]
            The output of the transformations.
        """
        return cast(Tuple[np.ndarray, ...], self._chain_transforms(patch, target))
@@ -0,0 +1,146 @@
1
+ """N2V manipulation transform."""
2
+
3
+ from typing import Any, Literal, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
8
+ from careamics.transforms.transform import Transform
9
+
10
+ from .pixel_manipulation import median_manipulate, uniform_manipulate
11
+ from .struct_mask_parameters import StructMaskParameters
12
+
13
+
14
class N2VManipulate(Transform):
    """
    Default augmentation for the N2V model.

    This transform expects C(Z)YX dimensions. The manipulation is applied to each
    channel independently.

    Parameters
    ----------
    roi_size : int, optional
        Size of the replacement area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels to mask, by default 0.2.
    strategy : Literal[ "uniform", "median" ], optional
        Replacement strategy, uniform or median, by default uniform.
    remove_center : bool, optional
        Whether to remove central pixel from patch, by default True. Only used by
        the uniform strategy.
    struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
        StructN2V mask axis, by default "none".
    struct_mask_span : int, optional
        StructN2V mask span, by default 5.
    seed : Optional[int], optional
        Random seed, by default None.

    Attributes
    ----------
    masked_pixel_percentage : float
        Percentage of pixels to mask.
    roi_size : int
        Size of the replacement area.
    strategy : Literal[ "uniform", "median" ]
        Replacement strategy, uniform or median.
    remove_center : bool
        Whether to remove central pixel from patch (uniform strategy only).
    struct_mask : Optional[StructMaskParameters]
        StructN2V mask parameters, None if no struct mask is used.
    rng : Generator
        Random number generator.
    """

    def __init__(
        self,
        roi_size: int = 11,
        masked_pixel_percentage: float = 0.2,
        strategy: Literal[
            "uniform", "median"
        ] = SupportedPixelManipulation.UNIFORM.value,
        remove_center: bool = True,
        struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
        struct_mask_span: int = 5,
        seed: Optional[int] = None,
    ):
        """Constructor.

        Parameters
        ----------
        roi_size : int, optional
            Size of the replacement area, by default 11.
        masked_pixel_percentage : float, optional
            Percentage of pixels to mask, by default 0.2.
        strategy : Literal[ "uniform", "median" ], optional
            Replacement strategy, uniform or median, by default uniform.
        remove_center : bool, optional
            Whether to remove central pixel from patch, by default True. Only used
            by the uniform strategy.
        struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
            StructN2V mask axis, by default "none".
        struct_mask_span : int, optional
            StructN2V mask span, by default 5.
        seed : Optional[int], optional
            Random seed, by default None.
        """
        self.masked_pixel_percentage = masked_pixel_percentage
        self.roi_size = roi_size
        self.strategy = strategy
        # forwarded to `uniform_manipulate` in `__call__`; the median strategy
        # does not receive it
        self.remove_center = remove_center

        if struct_mask_axis == SupportedStructAxis.NONE:
            self.struct_mask: Optional[StructMaskParameters] = None
        else:
            # "horizontal" maps to axis 0, any other value ("vertical") to axis 1
            self.struct_mask = StructMaskParameters(
                axis=0 if struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1,
                span=struct_mask_span,
            )

        # numpy random generator, shared by all calls so that a seed makes the
        # sequence of masks reproducible
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, *args: Any, **kwargs: Any
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        *args : Any
            Additional arguments, unused.
        **kwargs : Any
            Additional keyword arguments, unused.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            Masked patch, original (unmodified) patch, and mask. The masked patch
            and the mask have the same shape and dtype as `patch`.

        Raises
        ------
        ValueError
            If the masking strategy is neither uniform nor median.
        """
        masked = np.zeros_like(patch)
        mask = np.zeros_like(patch)
        if self.strategy == SupportedPixelManipulation.UNIFORM:
            # Iterate over the channels to apply manipulation separately
            for c in range(patch.shape[0]):
                masked[c, ...], mask[c, ...] = uniform_manipulate(
                    patch=patch[c, ...],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    remove_center=self.remove_center,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )
        elif self.strategy == SupportedPixelManipulation.MEDIAN:
            # Iterate over the channels to apply manipulation separately
            for c in range(patch.shape[0]):
                masked[c, ...], mask[c, ...] = median_manipulate(
                    patch=patch[c, ...],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )
        else:
            raise ValueError(f"Unknown masking strategy ({self.strategy}).")

        # the unmodified input is returned alongside the masked patch and the mask;
        # NOTE(review): presumably used downstream as the reference values for the
        # N2V loss at the masked pixels - confirm against the training module
        return masked, patch, mask
@@ -0,0 +1,243 @@
1
+ """Normalization and denormalization transforms for image patches."""
2
+
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+ from careamics.transforms.transform import Transform
9
+
10
+
11
+ def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
12
+ """Reshape stats to match the number of dimensions of the input image.
13
+
14
+ This allows to broadcast the stats (mean or std) to the image dimensions, and
15
+ thus directly perform a vectorial calculation.
16
+
17
+ Parameters
18
+ ----------
19
+ stats : list of float
20
+ List of stats, mean or standard deviation.
21
+ ndim : int
22
+ Number of dimensions of the image, including the C channel.
23
+
24
+ Returns
25
+ -------
26
+ NDArray
27
+ Reshaped stats.
28
+ """
29
+ return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
30
+
31
+
32
class Normalize(Transform):
    """
    Normalize an image or image patch.

    Normalization is a zero mean and unit variance. This transform expects C(Z)YX
    dimensions.

    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
    division by zero and that it returns a float32 image.

    Parameters
    ----------
    image_means : list of float
        Mean value per channel.
    image_stds : list of float
        Standard deviation value per channel.
    target_means : list of float, optional
        Target mean value per channel, by default None.
    target_stds : list of float, optional
        Target standard deviation value per channel, by default None.

    Attributes
    ----------
    image_means : list of float
        Mean value per channel.
    image_stds : list of float
        Standard deviation value per channel.
    target_means : list of float, optional
        Target mean value per channel, by default None.
    target_stds : list of float, optional
        Target standard deviation value per channel, by default None.
    eps : float
        Value added to the standard deviations to avoid division by zero.
    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
        target_means: Optional[list[float]] = None,
        target_stds: Optional[list[float]] = None,
    ):
        """Constructor.

        Parameters
        ----------
        image_means : list of float
            Mean value per channel.
        image_stds : list of float
            Standard deviation value per channel.
        target_means : list of float, optional
            Target mean value per channel, by default None.
        target_stds : list of float, optional
            Target standard deviation value per channel, by default None.
        """
        self.image_means = image_means
        self.image_stds = image_stds
        self.target_means = target_means
        self.target_stds = target_stds

        self.eps = 1e-6

    def __call__(
        self, patch: np.ndarray, target: Optional[NDArray] = None
    ) -> tuple[NDArray, Optional[NDArray]]:
        """Apply the transform to the source patch and the target (optional).

        The target is only normalized if both `target_means` and `target_stds` are
        set; otherwise `None` is returned in its place.

        Parameters
        ----------
        patch : NDArray
            Patch, 2D or 3D, shape C(Z)YX.
        target : NDArray, optional
            Target for the patch, by default None.

        Returns
        -------
        tuple of NDArray
            Transformed patch and target, the target can be returned as `None`.

        Raises
        ------
        ValueError
            If the number of image means or stds does not match the number of
            channels of `patch`, or if the number of target means or stds does not
            match the number of channels of `target`.
        """
        if len(self.image_means) != patch.shape[0]:
            raise ValueError(
                f"Number of means (got a list of size {len(self.image_means)}) and "
                f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
            )
        # without this check a single std would silently broadcast over all channels
        if len(self.image_stds) != patch.shape[0]:
            raise ValueError(
                f"Number of stds (got a list of size {len(self.image_stds)}) and "
                f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
            )

        # reshape mean and std and apply the normalization to the patch
        means = _reshape_stats(self.image_means, patch.ndim)
        stds = _reshape_stats(self.image_stds, patch.ndim)
        norm_patch = self._apply(patch, means, stds)

        # same for the target patch, provided target statistics are available
        if (
            target is not None
            and self.target_means is not None
            and self.target_stds is not None
        ):
            n_target_channels = target.shape[0]
            if (
                len(self.target_means) != n_target_channels
                or len(self.target_stds) != n_target_channels
            ):
                raise ValueError(
                    f"Number of target means (got a list of size "
                    f"{len(self.target_means)}) and target stds (got a list of size "
                    f"{len(self.target_stds)}) must match the number of channels of "
                    f"the target (got shape {target.shape} for C(Z)YX)."
                )

            target_means = _reshape_stats(self.target_means, target.ndim)
            target_stds = _reshape_stats(self.target_stds, target.ndim)
            norm_target = self._apply(target, target_means, target_stds)
        else:
            norm_target = None

        return norm_patch, norm_target

    def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
        """
        Apply the transform to the image.

        Parameters
        ----------
        patch : NDArray
            Image patch, 2D or 3D, shape C(Z)YX.
        mean : NDArray
            Mean values, broadcastable to `patch`.
        std : NDArray
            Standard deviations, broadcastable to `patch`.

        Returns
        -------
        NDArray
            Normalized image patch, as float32.
        """
        return ((patch - mean) / (std + self.eps)).astype(np.float32)
153
+
154
+
155
class Denormalize:
    """
    Denormalize an image.

    Denormalization is performed expecting a zero mean and unit variance input. This
    transform expects BC(Z)YX dimensions, i.e. a batch of patches with the channel
    on axis 1.

    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
    division by zero during the normalization step, which is taken into account during
    denormalization.

    Parameters
    ----------
    image_means : list or tuple of float
        Mean value per channel.
    image_stds : list or tuple of float
        Standard deviation value per channel.

    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
    ):
        """Constructor.

        Parameters
        ----------
        image_means : list of float
            Mean value per channel.
        image_stds : list of float
            Standard deviation value per channel.
        """
        self.image_means = image_means
        self.image_stds = image_stds

        self.eps = 1e-6

    def __call__(self, patch: NDArray) -> NDArray:
        """Reverse the normalization operation for a batch of patches.

        Parameters
        ----------
        patch : NDArray
            Patch, 2D or 3D, shape BC(Z)YX.

        Returns
        -------
        NDArray
            Transformed array, as float32.

        Raises
        ------
        ValueError
            If the number of means does not match the number of channels of
            `patch`, or if the number of stds does not match the number of means.
        """
        n_channels = len(self.image_means)
        if n_channels != patch.shape[1]:
            raise ValueError(
                f"Number of means (got a list of size {len(self.image_means)}) and "
                f"number of channels (got shape {patch.shape} for BC(Z)YX) do not "
                f"match."
            )
        # without this check a single std would silently broadcast over all channels
        if len(self.image_stds) != n_channels:
            raise ValueError(
                f"Number of stds (got a list of size {len(self.image_stds)}) and "
                f"number of means (got a list of size {n_channels}) do not match."
            )

        # shape the stats as (1, C, 1, ..., 1) so that they broadcast along the
        # channel axis (axis 1) of the BC(Z)YX batch
        stats_shape = (1, -1) + (1,) * (patch.ndim - 2)
        means = np.array(self.image_means).reshape(stats_shape)
        stds = np.array(self.image_stds).reshape(stats_shape)

        denorm_array = self._apply(patch, means, stds)

        return denorm_array.astype(np.float32)

    def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
        """
        Apply the transform to the image.

        Parameters
        ----------
        array : NDArray
            Batch of image patches, 2D or 3D, shape BC(Z)YX.
        mean : NDArray
            Mean values, broadcastable to `array`.
        std : NDArray
            Standard deviations, broadcastable to `array`.

        Returns
        -------
        NDArray
            Denormalized image array.
        """
        return array * (std + self.eps) + mean