careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (69) hide show
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,165 @@
1
+ """Module containing functions to convert prediction outputs to desired form."""
2
+
3
+ from typing import Any, List, Literal, Tuple, Union, overload
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+ from ..config.tile_information import TileInformation
9
+ from .stitch_prediction import stitch_prediction
10
+
11
+
12
def convert_outputs(
    predictions: List[Any], tiled: bool
) -> Union[List[NDArray], NDArray]:
    """
    Turn raw `Trainer.predict` outputs into arrays.

    Batches are first merged with `combine_batches`. For tiled predictions the
    leading (sample) axis of every tile is dropped and the tiles are then stitched
    with `stitch_prediction`.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    list of numpy.ndarray or numpy.ndarray
        Converted predictions. An empty input is returned unchanged, and a result
        holding exactly one array is returned as that array rather than as a list.
    """
    # nothing to combine or stitch
    if len(predictions) == 0:
        return predictions

    # the branches are kept separate so that mypy can pick the right overload
    if not tiled:
        outputs = combine_batches(predictions, tiled)
    else:
        combined = combine_batches(predictions, tiled)
        # `stitch_prediction` expects tiles without the (singleton) S dimension
        squeezed_tiles = [tile[0] for tile in combined[0]]
        outputs = stitch_prediction(squeezed_tiles, combined[1])

    # TODO: add this in? Returns output with same axes as input
    # Won't work with tiling rn because stitch_prediction func removes S axis
    # predictions = reshape(predictions, axes)
    # At least make sure stitched prediction and non-tiled prediction have matching axes

    # TODO: might want to remove this
    return outputs[0] if len(outputs) == 1 else outputs
53
+
54
+
55
+ # for mypy
56
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[True]
) -> Tuple[List[NDArray], List[TileInformation]]: ...


# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Literal[False]
) -> List[NDArray]: ...


# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...


def combine_batches(
    predictions: List[Any], tiled: bool
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
    """
    Merge batched predictions, dispatching on whether they are tiled.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches; the tuple form is returned when `tiled` is true.
    """
    if not tiled:
        return _combine_untiled_batches(predictions)
    return _combine_tiled_batches(predictions)
98
+
99
+
100
def _combine_tiled_batches(
    predictions: List[Tuple[NDArray, List[TileInformation]]]
) -> Tuple[List[NDArray], List[TileInformation]]:
    """
    Merge the batches of a tiled prediction.

    Each element of `predictions` is unpacked as a pair of (batch array, list of
    tile information). The tile information lists are flattened into one list and
    the batch arrays are merged with `_combine_untiled_batches`.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches.
    """
    batch_arrays: List[NDArray] = []
    flat_tile_infos: List[TileInformation] = []
    for batch_array, batch_tile_infos in predictions:
        batch_arrays.append(batch_array)
        # flatten the per-batch lists into a single list
        flat_tile_infos.extend(batch_tile_infos)

    split_tiles: List[NDArray] = _combine_untiled_batches(batch_arrays)
    return split_tiles, flat_tile_infos
124
+
125
+
126
+ def _combine_untiled_batches(predictions: List[NDArray]) -> List[NDArray]:
127
+ """
128
+ Combine batches from un-tiled output.
129
+
130
+ Parameters
131
+ ----------
132
+ predictions : list
133
+ Predictions that are output from `Trainer.predict`.
134
+
135
+ Returns
136
+ -------
137
+ list of nunpy.ndarray
138
+ Combined batches.
139
+ """
140
+ prediction_concat: NDArray = np.concatenate(predictions, axis=0)
141
+ prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
142
+ return prediction_split
143
+
144
+
145
def reshape(predictions: List[NDArray], axes: str) -> List[NDArray]:
    """
    Drop the sample and/or channel axes that are absent from `axes`.

    If "C" is not in `axes`, index 0 of the second axis is taken for every
    prediction; if "S" is not in `axes`, index 0 of the first axis is then taken.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    axes : str
        Axes SC(Z)YX.

    Returns
    -------
    List[NDArray]
        Reshaped predictions.
    """
    drop_channel = "C" not in axes
    drop_sample = "S" not in axes

    # nothing to remove: hand back the input list itself
    if not (drop_channel or drop_sample):
        return predictions

    reshaped = []
    for pred in predictions:
        if drop_channel:
            pred = pred[:, 0]
        if drop_sample:
            pred = pred[0]
        reshaped.append(pred)
    return reshaped
@@ -0,0 +1,100 @@
1
+ """Prediction utility functions."""
2
+
3
+ from typing import List
4
+
5
+ import numpy as np
6
+
7
+ from careamics.config.tile_information import TileInformation
8
+
9
+
10
+ # TODO: why not allow input and output of torch.tensor ?
11
def stitch_prediction(
    tiles: List[np.ndarray],
    tile_infos: List[TileInformation],
) -> List[np.ndarray]:
    """
    Stitch tiles back together to form one or several full images.

    The tiles are grouped per image using the `last_tile` flag of their tile
    information: every flagged tile closes the group of the current image. Each
    group is then stitched with `stitch_prediction_single`.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates. Can contain tiles
        from multiple images.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    list of numpy.ndarray
        Full image(s).
    """
    # exclusive end index of each image: one past the position of its last tile
    is_last_tile = [info.last_tile for info in tile_infos]
    image_stops = np.where(is_last_tile)[0] + 1

    stitched_images = []
    start = 0
    for stop in image_stops:
        # tiles and infos in [start, stop) all belong to the same image
        stitched_images.append(
            stitch_prediction_single(tiles[start:stop], tile_infos[start:stop])
        )
        start = stop
    return stitched_images
52
+
53
+
54
def stitch_prediction_single(
    tiles: List[np.ndarray],
    tile_infos: List[TileInformation],
) -> np.ndarray:
    """
    Stitch tiles back together to form a full image.

    The first axis of each tile is treated as the channel axis and is kept whole;
    the remaining axes are cropped using `overlap_crop_coords` and written into the
    output at `stitch_coords`. The output has the `array_shape` of the first tile
    information and is of type float32.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    numpy.ndarray
        Full image.
    """
    # the full image shape is read from the first tile information
    full_shape = tile_infos[0].array_shape
    stitched = np.zeros(full_shape, dtype=np.float32)

    for tile, info in zip(tiles, tile_infos):
        # keep all channels, crop the other axes to remove the overlap
        channel_slice = slice(0, tile.shape[0])
        crop_slices = tuple(
            slice(bounds[0], bounds[1]) for bounds in info.overlap_crop_coords
        )
        cropped = tile[(channel_slice,) + crop_slices]

        # location of the cropped tile in the full image
        target_slices = tuple(
            slice(bounds[0], bounds[1]) for bounds in info.stitch_coords
        )
        stitched[(Ellipsis,) + target_slices] = cropped.astype(np.float32)

    return stitched
@@ -60,7 +60,7 @@ class N2VManipulate(Transform):
60
60
  remove_center: bool = True,
61
61
  struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
62
62
  struct_mask_span: int = 5,
63
- seed: Optional[int] = None, # TODO use in pixel manipulation
63
+ seed: Optional[int] = None,
64
64
  ):
65
65
  """Constructor.
66
66
 
@@ -127,6 +127,7 @@ class N2VManipulate(Transform):
127
127
  subpatch_size=self.roi_size,
128
128
  remove_center=self.remove_center,
129
129
  struct_params=self.struct_mask,
130
+ rng=self.rng,
130
131
  )
131
132
  elif self.strategy == SupportedPixelManipulation.MEDIAN:
132
133
  # Iterate over the channels to apply manipulation separately
@@ -136,6 +137,7 @@ class N2VManipulate(Transform):
136
137
  mask_pixel_percentage=self.masked_pixel_percentage,
137
138
  subpatch_size=self.roi_size,
138
139
  struct_params=self.struct_mask,
140
+ rng=self.rng,
139
141
  )
140
142
  else:
141
143
  raise ValueError(f"Unknown masking strategy ({self.strategy}).")
@@ -1,12 +1,34 @@
1
1
  """Normalization and denormalization transforms for image patches."""
2
2
 
3
- from typing import Optional, Tuple
3
+ from typing import Optional
4
4
 
5
5
  import numpy as np
6
+ from numpy.typing import NDArray
6
7
 
7
8
  from careamics.transforms.transform import Transform
8
9
 
9
10
 
11
+ def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
12
+ """Reshape stats to match the number of dimensions of the input image.
13
+
14
+ This allows to broadcast the stats (mean or std) to the image dimensions, and
15
+ thus directly perform a vectorial calculation.
16
+
17
+ Parameters
18
+ ----------
19
+ stats : list of float
20
+ List of stats, mean or standard deviation.
21
+ ndim : int
22
+ Number of dimensions of the image, including the C channel.
23
+
24
+ Returns
25
+ -------
26
+ NDArray
27
+ Reshaped stats.
28
+ """
29
+ return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
30
+
31
+
10
32
  class Normalize(Transform):
11
33
  """
12
34
  Normalize an image or image patch.
@@ -19,154 +41,203 @@ class Normalize(Transform):
19
41
 
20
42
  Parameters
21
43
  ----------
22
- mean : float
23
- Mean value.
24
- std : float
25
- Standard deviation value.
44
+ image_means : list of float
45
+ Mean value per channel.
46
+ image_stds : list of float
47
+ Standard deviation value per channel.
48
+ target_means : list of float, optional
49
+ Target mean value per channel, by default None.
50
+ target_stds : list of float, optional
51
+ Target standard deviation value per channel, by default None.
26
52
 
27
53
  Attributes
28
54
  ----------
29
- mean : float
30
- Mean value.
31
- std : float
32
- Standard deviation value.
55
+ image_means : list of float
56
+ Mean value per channel.
57
+ image_stds : list of float
58
+ Standard deviation value per channel.
59
+ target_means : list of float, optional
60
+ Target mean value per channel, by default None.
61
+ target_stds : list of float, optional
62
+ Target standard deviation value per channel, by default None.
33
63
  """
34
64
 
35
65
  def __init__(
36
66
  self,
37
- mean: float,
38
- std: float,
67
+ image_means: list[float],
68
+ image_stds: list[float],
69
+ target_means: Optional[list[float]] = None,
70
+ target_stds: Optional[list[float]] = None,
39
71
  ):
40
72
  """Constructor.
41
73
 
42
74
  Parameters
43
75
  ----------
44
- mean : float
45
- Mean value.
46
- std : float
47
- Standard deviation value.
76
+ image_means : list of float
77
+ Mean value per channel.
78
+ image_stds : list of float
79
+ Standard deviation value per channel.
80
+ target_means : list of float, optional
81
+ Target mean value per channel, by default None.
82
+ target_stds : list of float, optional
83
+ Target standard deviation value per channel, by default None.
48
84
  """
49
- self.mean = mean
50
- self.std = std
85
+ self.image_means = image_means
86
+ self.image_stds = image_stds
87
+ self.target_means = target_means
88
+ self.target_stds = target_stds
89
+
51
90
  self.eps = 1e-6
52
91
 
53
92
  def __call__(
54
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
55
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
93
+ self, patch: np.ndarray, target: Optional[NDArray] = None
94
+ ) -> tuple[NDArray, Optional[NDArray]]:
56
95
  """Apply the transform to the source patch and the target (optional).
57
96
 
58
97
  Parameters
59
98
  ----------
60
- patch : np.ndarray
99
+ patch : NDArray
61
100
  Patch, 2D or 3D, shape C(Z)YX.
62
- target : Optional[np.ndarray], optional
101
+ target : NDArray, optional
63
102
  Target for the patch, by default None.
64
103
 
65
104
  Returns
66
105
  -------
67
- Tuple[np.ndarray, Optional[np.ndarray]]
68
- Transformed patch and target.
106
+ tuple of NDArray
107
+ Transformed patch and target, the target can be returned as `None`.
69
108
  """
70
- norm_patch = self._apply(patch)
71
- norm_target = self._apply(target) if target is not None else None
109
+ if len(self.image_means) != patch.shape[0]:
110
+ raise ValueError(
111
+ f"Number of means (got a list of size {len(self.image_means)}) and "
112
+ f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
113
+ )
114
+
115
+ # reshape mean and std and apply the normalization to the patch
116
+ means = _reshape_stats(self.image_means, patch.ndim)
117
+ stds = _reshape_stats(self.image_stds, patch.ndim)
118
+ norm_patch = self._apply(patch, means, stds)
119
+
120
+ # same for the target patch
121
+ if (
122
+ target is not None
123
+ and self.target_means is not None
124
+ and self.target_stds is not None
125
+ ):
126
+ target_means = _reshape_stats(self.target_means, target.ndim)
127
+ target_stds = _reshape_stats(self.target_stds, target.ndim)
128
+ norm_target = self._apply(target, target_means, target_stds)
129
+ else:
130
+ norm_target = None
72
131
 
73
132
  return norm_patch, norm_target
74
133
 
75
- def _apply(self, patch: np.ndarray) -> np.ndarray:
134
+ def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
76
135
  """
77
136
  Apply the transform to the image.
78
137
 
79
138
  Parameters
80
139
  ----------
81
- patch : np.ndarray
140
+ patch : NDArray
82
141
  Image patch, 2D or 3D, shape C(Z)YX.
142
+ mean : NDArray
143
+ Mean values.
144
+ std : NDArray
145
+ Standard deviations.
83
146
 
84
147
  Returns
85
148
  -------
86
- np.ndarray
87
- Normalizedimage patch.
149
+ NDArray
150
+ Normalized image patch.
88
151
  """
89
- return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
152
+ return ((patch - mean) / (std + self.eps)).astype(np.float32)
90
153
 
91
154
 
92
155
  class Denormalize:
93
156
  """
94
- Denormalize an image or image patch.
157
+ Denormalize an image.
95
158
 
96
159
  Denormalization is performed expecting a zero mean and unit variance input. This
97
160
  transform expects C(Z)YX dimensions.
98
161
 
99
- Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
162
+ Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
100
163
  division by zero during the normalization step, which is taken into account during
101
164
  denormalization.
102
165
 
103
166
  Parameters
104
167
  ----------
105
- mean : float
106
- Mean value.
107
- std : float
108
- Standard deviation value.
168
+ image_means : list or tuple of float
169
+ Mean value per channel.
170
+ image_stds : list or tuple of float
171
+ Standard deviation value per channel.
109
172
 
110
- Attributes
111
- ----------
112
- mean : float
113
- Mean value.
114
- std : float
115
- Standard deviation value.
116
173
  """
117
174
 
118
175
  def __init__(
119
176
  self,
120
- mean: float,
121
- std: float,
177
+ image_means: list[float],
178
+ image_stds: list[float],
122
179
  ):
123
180
  """Constructor.
124
181
 
125
182
  Parameters
126
183
  ----------
127
- mean : float
128
- Mean.
129
- std : float
130
- Standard deviation.
184
+ image_means : list of float
185
+ Mean value per channel.
186
+ image_stds : list of float
187
+ Standard deviation value per channel.
131
188
  """
132
- self.mean = mean
133
- self.std = std
189
+ self.image_means = image_means
190
+ self.image_stds = image_stds
191
+
134
192
  self.eps = 1e-6
135
193
 
136
- def __call__(
137
- self, patch: np.ndarray, target: Optional[np.ndarray] = None
138
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
139
- """Apply the transform to the source patch and the target (optional).
194
+ def __call__(self, patch: NDArray) -> NDArray:
195
+ """Reverse the normalization operation for a batch of patches.
140
196
 
141
197
  Parameters
142
198
  ----------
143
- patch : np.ndarray
144
- Patch, 2D or 3D, shape C(Z)YX.
145
- target : Optional[np.ndarray], optional
146
- Target for the patch, by default None.
199
+ patch : NDArray
200
+ Patch, 2D or 3D, shape BC(Z)YX.
147
201
 
148
202
  Returns
149
203
  -------
150
- Tuple[np.ndarray, Optional[np.ndarray]]
151
- Transformed patch and target.
204
+ NDArray
205
+ Transformed array.
152
206
  """
153
- norm_patch = self._apply(patch)
154
- norm_target = self._apply(target) if target is not None else None
207
+ if len(self.image_means) != patch.shape[1]:
208
+ raise ValueError(
209
+ f"Number of means (got a list of size {len(self.image_means)}) and "
210
+ f"number of channels (got shape {patch.shape} for BC(Z)YX) do not "
211
+ f"match."
212
+ )
155
213
 
156
- return norm_patch, norm_target
214
+ means = _reshape_stats(self.image_means, patch.ndim)
215
+ stds = _reshape_stats(self.image_stds, patch.ndim)
216
+
217
+ denorm_array = self._apply(
218
+ patch,
219
+ np.swapaxes(means, 0, 1), # swap axes as C channel is axis 1
220
+ np.swapaxes(stds, 0, 1),
221
+ )
222
+
223
+ return denorm_array.astype(np.float32)
157
224
 
158
- def _apply(self, patch: np.ndarray) -> np.ndarray:
225
+ def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
159
226
  """
160
227
  Apply the transform to the image.
161
228
 
162
229
  Parameters
163
230
  ----------
164
- patch : np.ndarray
231
+ array : NDArray
165
232
  Image patch, 2D or 3D, shape C(Z)YX.
233
+ mean : NDArray
234
+ Mean values.
235
+ std : NDArray
236
+ Standard deviations.
166
237
 
167
238
  Returns
168
239
  -------
169
- np.ndarray
170
- Denormalized image patch.
240
+ NDArray
241
+ Denormalized image array.
171
242
  """
172
- return patch * (self.std + self.eps) + self.mean
243
+ return array * (std + self.eps) + mean