careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Module containing functions to convert prediction outputs to desired form."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Literal, Tuple, Union, overload
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from ..config.tile_information import TileInformation
|
|
9
|
+
from .stitch_prediction import stitch_prediction
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def convert_outputs(
|
|
13
|
+
predictions: List[Any], tiled: bool
|
|
14
|
+
) -> Union[List[NDArray], NDArray]:
|
|
15
|
+
"""
|
|
16
|
+
Convert the outputs to the desired form.
|
|
17
|
+
|
|
18
|
+
Parameters
|
|
19
|
+
----------
|
|
20
|
+
predictions : list
|
|
21
|
+
Predictions that are output from `Trainer.predict`.
|
|
22
|
+
tiled : bool
|
|
23
|
+
Whether the predictions are tiled.
|
|
24
|
+
|
|
25
|
+
Returns
|
|
26
|
+
-------
|
|
27
|
+
list of numpy.ndarray or numpy.ndarray
|
|
28
|
+
List of arrays with the axes SC(Z)YX. If there is only 1 output it will not
|
|
29
|
+
be in a list.
|
|
30
|
+
"""
|
|
31
|
+
if len(predictions) == 0:
|
|
32
|
+
return predictions
|
|
33
|
+
|
|
34
|
+
# this layout is to stop mypy complaining
|
|
35
|
+
if tiled:
|
|
36
|
+
predictions_comb = combine_batches(predictions, tiled)
|
|
37
|
+
# remove sample dimension (always 1); the `stitch_prediction` func expects no S dim
|
|
38
|
+
tiles = [pred[0] for pred in predictions_comb[0]]
|
|
39
|
+
tile_infos = predictions_comb[1]
|
|
40
|
+
predictions_output = stitch_prediction(tiles, tile_infos)
|
|
41
|
+
else:
|
|
42
|
+
predictions_output = combine_batches(predictions, tiled)
|
|
43
|
+
|
|
44
|
+
# TODO: add this in? Returns output with same axes as input
|
|
45
|
+
# Won't work with tiling rn because stitch_prediction func removes S axis
|
|
46
|
+
# predictions = reshape(predictions, axes)
|
|
47
|
+
# At least make sure stitched prediction and non-tiled prediction have matching axes
|
|
48
|
+
|
|
49
|
+
# TODO: might want to remove this
|
|
50
|
+
if len(predictions_output) == 1:
|
|
51
|
+
return predictions_output[0]
|
|
52
|
+
return predictions_output
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# for mypy
|
|
56
|
+
@overload
|
|
57
|
+
def combine_batches( # numpydoc ignore=GL08
|
|
58
|
+
predictions: List[Any], tiled: Literal[True]
|
|
59
|
+
) -> Tuple[List[NDArray], List[TileInformation]]: ...
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# for mypy
|
|
63
|
+
@overload
|
|
64
|
+
def combine_batches( # numpydoc ignore=GL08
|
|
65
|
+
predictions: List[Any], tiled: Literal[False]
|
|
66
|
+
) -> List[NDArray]: ...
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# for mypy
|
|
70
|
+
@overload
|
|
71
|
+
def combine_batches( # numpydoc ignore=GL08
|
|
72
|
+
predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
|
|
73
|
+
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def combine_batches(
|
|
77
|
+
predictions: List[Any], tiled: bool
|
|
78
|
+
) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
|
|
79
|
+
"""
|
|
80
|
+
If predictions are in batches, they will be combined.
|
|
81
|
+
|
|
82
|
+
Parameters
|
|
83
|
+
----------
|
|
84
|
+
predictions : list
|
|
85
|
+
Predictions that are output from `Trainer.predict`.
|
|
86
|
+
tiled : bool
|
|
87
|
+
Whether the predictions are tiled.
|
|
88
|
+
|
|
89
|
+
Returns
|
|
90
|
+
-------
|
|
91
|
+
(list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
|
|
92
|
+
Combined batches.
|
|
93
|
+
"""
|
|
94
|
+
if tiled:
|
|
95
|
+
return _combine_tiled_batches(predictions)
|
|
96
|
+
else:
|
|
97
|
+
return _combine_untiled_batches(predictions)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _combine_tiled_batches(
|
|
101
|
+
predictions: List[Tuple[NDArray, List[TileInformation]]]
|
|
102
|
+
) -> Tuple[List[NDArray], List[TileInformation]]:
|
|
103
|
+
"""
|
|
104
|
+
Combine batches from tiled output.
|
|
105
|
+
|
|
106
|
+
Parameters
|
|
107
|
+
----------
|
|
108
|
+
predictions : list
|
|
109
|
+
Predictions that are output from `Trainer.predict`.
|
|
110
|
+
|
|
111
|
+
Returns
|
|
112
|
+
-------
|
|
113
|
+
tuple of (list of numpy.ndarray, list of TileInformation)
|
|
114
|
+
Combined batches.
|
|
115
|
+
"""
|
|
116
|
+
# turn list of lists into single list
|
|
117
|
+
tile_infos = [
|
|
118
|
+
tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
|
|
119
|
+
]
|
|
120
|
+
prediction_tiles: List[NDArray] = _combine_untiled_batches(
|
|
121
|
+
[preds for preds, _ in predictions]
|
|
122
|
+
)
|
|
123
|
+
return prediction_tiles, tile_infos
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _combine_untiled_batches(predictions: List[NDArray]) -> List[NDArray]:
|
|
127
|
+
"""
|
|
128
|
+
Combine batches from un-tiled output.
|
|
129
|
+
|
|
130
|
+
Parameters
|
|
131
|
+
----------
|
|
132
|
+
predictions : list
|
|
133
|
+
Predictions that are output from `Trainer.predict`.
|
|
134
|
+
|
|
135
|
+
Returns
|
|
136
|
+
-------
|
|
137
|
+
list of numpy.ndarray
|
|
138
|
+
Combined batches.
|
|
139
|
+
"""
|
|
140
|
+
prediction_concat: NDArray = np.concatenate(predictions, axis=0)
|
|
141
|
+
prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
|
|
142
|
+
return prediction_split
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def reshape(predictions: List[NDArray], axes: str) -> List[NDArray]:
|
|
146
|
+
"""
|
|
147
|
+
Reshape predictions to have dimensions of input.
|
|
148
|
+
|
|
149
|
+
Parameters
|
|
150
|
+
----------
|
|
151
|
+
predictions : list
|
|
152
|
+
Predictions that are output from `Trainer.predict`.
|
|
153
|
+
axes : str
|
|
154
|
+
Axes SC(Z)YX.
|
|
155
|
+
|
|
156
|
+
Returns
|
|
157
|
+
-------
|
|
158
|
+
List[NDArray]
|
|
159
|
+
Reshaped predictions.
|
|
160
|
+
"""
|
|
161
|
+
if "C" not in axes:
|
|
162
|
+
predictions = [pred[:, 0] for pred in predictions]
|
|
163
|
+
if "S" not in axes:
|
|
164
|
+
predictions = [pred[0] for pred in predictions]
|
|
165
|
+
return predictions
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Prediction utility functions."""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.config.tile_information import TileInformation
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# TODO: why not allow input and output of torch.tensor ?
|
|
11
|
+
def stitch_prediction(
|
|
12
|
+
tiles: List[np.ndarray],
|
|
13
|
+
tile_infos: List[TileInformation],
|
|
14
|
+
) -> List[np.ndarray]:
|
|
15
|
+
"""
|
|
16
|
+
Stitch tiles back together to form a full image(s).
|
|
17
|
+
|
|
18
|
+
Tiles are of dimensions C(Z)YX, where C is the number of channels and can be a
|
|
19
|
+
singleton dimension.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
tiles : list of numpy.ndarray
|
|
24
|
+
Cropped tiles and their respective stitching coordinates. Can contain tiles
|
|
25
|
+
from multiple images.
|
|
26
|
+
tile_infos : list of TileInformation
|
|
27
|
+
List of information and coordinates obtained from
|
|
28
|
+
`dataset.tiled_patching.extract_tiles`.
|
|
29
|
+
|
|
30
|
+
Returns
|
|
31
|
+
-------
|
|
32
|
+
list of numpy.ndarray
|
|
33
|
+
Full image(s).
|
|
34
|
+
"""
|
|
35
|
+
# Find where to split the lists so that only info from one image is contained.
|
|
36
|
+
# Do this by locating the last tiles of each image.
|
|
37
|
+
last_tiles = [tile_info.last_tile for tile_info in tile_infos]
|
|
38
|
+
last_tile_position = np.where(last_tiles)[0]
|
|
39
|
+
image_slices = [
|
|
40
|
+
slice(
|
|
41
|
+
None if i == 0 else last_tile_position[i - 1] + 1, last_tile_position[i] + 1
|
|
42
|
+
)
|
|
43
|
+
for i in range(len(last_tile_position))
|
|
44
|
+
]
|
|
45
|
+
image_predictions = []
|
|
46
|
+
# slice the lists and apply stitch_prediction_single to each in turn.
|
|
47
|
+
for image_slice in image_slices:
|
|
48
|
+
image_predictions.append(
|
|
49
|
+
stitch_prediction_single(tiles[image_slice], tile_infos[image_slice])
|
|
50
|
+
)
|
|
51
|
+
return image_predictions
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def stitch_prediction_single(
|
|
55
|
+
tiles: List[np.ndarray],
|
|
56
|
+
tile_infos: List[TileInformation],
|
|
57
|
+
) -> np.ndarray:
|
|
58
|
+
"""
|
|
59
|
+
Stitch tiles back together to form a full image.
|
|
60
|
+
|
|
61
|
+
Tiles are of dimensions C(Z)YX, where C is the number of channels and can be a
|
|
62
|
+
singleton dimension.
|
|
63
|
+
|
|
64
|
+
Parameters
|
|
65
|
+
----------
|
|
66
|
+
tiles : list of numpy.ndarray
|
|
67
|
+
Cropped tiles and their respective stitching coordinates.
|
|
68
|
+
tile_infos : list of TileInformation
|
|
69
|
+
List of information and coordinates obtained from
|
|
70
|
+
`dataset.tiled_patching.extract_tiles`.
|
|
71
|
+
|
|
72
|
+
Returns
|
|
73
|
+
-------
|
|
74
|
+
numpy.ndarray
|
|
75
|
+
Full image.
|
|
76
|
+
"""
|
|
77
|
+
# retrieve whole array size
|
|
78
|
+
input_shape = tile_infos[0].array_shape
|
|
79
|
+
predicted_image = np.zeros(input_shape, dtype=np.float32)
|
|
80
|
+
|
|
81
|
+
for tile, tile_info in zip(tiles, tile_infos):
|
|
82
|
+
n_channels = tile.shape[0]
|
|
83
|
+
|
|
84
|
+
# Compute coordinates for cropping predicted tile
|
|
85
|
+
slices = (slice(0, n_channels),) + tuple(
|
|
86
|
+
[slice(c[0], c[1]) for c in tile_info.overlap_crop_coords]
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
# Crop predicted tile according to overlap coordinates
|
|
90
|
+
cropped_tile = tile[slices]
|
|
91
|
+
|
|
92
|
+
# Insert cropped tile into predicted image using stitch coordinates
|
|
93
|
+
predicted_image[
|
|
94
|
+
(
|
|
95
|
+
...,
|
|
96
|
+
*[slice(c[0], c[1]) for c in tile_info.stitch_coords],
|
|
97
|
+
)
|
|
98
|
+
] = cropped_tile.astype(np.float32)
|
|
99
|
+
|
|
100
|
+
return predicted_image
|
|
@@ -60,7 +60,7 @@ class N2VManipulate(Transform):
|
|
|
60
60
|
remove_center: bool = True,
|
|
61
61
|
struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
62
62
|
struct_mask_span: int = 5,
|
|
63
|
-
seed: Optional[int] = None,
|
|
63
|
+
seed: Optional[int] = None,
|
|
64
64
|
):
|
|
65
65
|
"""Constructor.
|
|
66
66
|
|
|
@@ -127,6 +127,7 @@ class N2VManipulate(Transform):
|
|
|
127
127
|
subpatch_size=self.roi_size,
|
|
128
128
|
remove_center=self.remove_center,
|
|
129
129
|
struct_params=self.struct_mask,
|
|
130
|
+
rng=self.rng,
|
|
130
131
|
)
|
|
131
132
|
elif self.strategy == SupportedPixelManipulation.MEDIAN:
|
|
132
133
|
# Iterate over the channels to apply manipulation separately
|
|
@@ -136,6 +137,7 @@ class N2VManipulate(Transform):
|
|
|
136
137
|
mask_pixel_percentage=self.masked_pixel_percentage,
|
|
137
138
|
subpatch_size=self.roi_size,
|
|
138
139
|
struct_params=self.struct_mask,
|
|
140
|
+
rng=self.rng,
|
|
139
141
|
)
|
|
140
142
|
else:
|
|
141
143
|
raise ValueError(f"Unknown masking strategy ({self.strategy}).")
|
|
@@ -1,12 +1,34 @@
|
|
|
1
1
|
"""Normalization and denormalization transforms for image patches."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.transforms.transform import Transform
|
|
8
9
|
|
|
9
10
|
|
|
11
|
+
def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
|
|
12
|
+
"""Reshape stats to match the number of dimensions of the input image.
|
|
13
|
+
|
|
14
|
+
This allows to broadcast the stats (mean or std) to the image dimensions, and
|
|
15
|
+
thus directly perform a vectorial calculation.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
stats : list of float
|
|
20
|
+
List of stats, mean or standard deviation.
|
|
21
|
+
ndim : int
|
|
22
|
+
Number of dimensions of the image, including the C channel.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
NDArray
|
|
27
|
+
Reshaped stats.
|
|
28
|
+
"""
|
|
29
|
+
return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
|
|
30
|
+
|
|
31
|
+
|
|
10
32
|
class Normalize(Transform):
|
|
11
33
|
"""
|
|
12
34
|
Normalize an image or image patch.
|
|
@@ -19,154 +41,203 @@ class Normalize(Transform):
|
|
|
19
41
|
|
|
20
42
|
Parameters
|
|
21
43
|
----------
|
|
22
|
-
|
|
23
|
-
Mean value.
|
|
24
|
-
|
|
25
|
-
Standard deviation value.
|
|
44
|
+
image_means : list of float
|
|
45
|
+
Mean value per channel.
|
|
46
|
+
image_stds : list of float
|
|
47
|
+
Standard deviation value per channel.
|
|
48
|
+
target_means : list of float, optional
|
|
49
|
+
Target mean value per channel, by default None.
|
|
50
|
+
target_stds : list of float, optional
|
|
51
|
+
Target standard deviation value per channel, by default None.
|
|
26
52
|
|
|
27
53
|
Attributes
|
|
28
54
|
----------
|
|
29
|
-
|
|
30
|
-
Mean value.
|
|
31
|
-
|
|
32
|
-
Standard deviation value.
|
|
55
|
+
image_means : list of float
|
|
56
|
+
Mean value per channel.
|
|
57
|
+
image_stds : list of float
|
|
58
|
+
Standard deviation value per channel.
|
|
59
|
+
target_means : list of float, optional
|
|
60
|
+
Target mean value per channel, by default None.
|
|
61
|
+
target_stds : list of float, optional
|
|
62
|
+
Target standard deviation value per channel, by default None.
|
|
33
63
|
"""
|
|
34
64
|
|
|
35
65
|
def __init__(
|
|
36
66
|
self,
|
|
37
|
-
|
|
38
|
-
|
|
67
|
+
image_means: list[float],
|
|
68
|
+
image_stds: list[float],
|
|
69
|
+
target_means: Optional[list[float]] = None,
|
|
70
|
+
target_stds: Optional[list[float]] = None,
|
|
39
71
|
):
|
|
40
72
|
"""Constructor.
|
|
41
73
|
|
|
42
74
|
Parameters
|
|
43
75
|
----------
|
|
44
|
-
|
|
45
|
-
Mean value.
|
|
46
|
-
|
|
47
|
-
Standard deviation value.
|
|
76
|
+
image_means : list of float
|
|
77
|
+
Mean value per channel.
|
|
78
|
+
image_stds : list of float
|
|
79
|
+
Standard deviation value per channel.
|
|
80
|
+
target_means : list of float, optional
|
|
81
|
+
Target mean value per channel, by default None.
|
|
82
|
+
target_stds : list of float, optional
|
|
83
|
+
Target standard deviation value per channel, by default None.
|
|
48
84
|
"""
|
|
49
|
-
self.
|
|
50
|
-
self.
|
|
85
|
+
self.image_means = image_means
|
|
86
|
+
self.image_stds = image_stds
|
|
87
|
+
self.target_means = target_means
|
|
88
|
+
self.target_stds = target_stds
|
|
89
|
+
|
|
51
90
|
self.eps = 1e-6
|
|
52
91
|
|
|
53
92
|
def __call__(
|
|
54
|
-
self, patch: np.ndarray, target: Optional[
|
|
55
|
-
) ->
|
|
93
|
+
self, patch: np.ndarray, target: Optional[NDArray] = None
|
|
94
|
+
) -> tuple[NDArray, Optional[NDArray]]:
|
|
56
95
|
"""Apply the transform to the source patch and the target (optional).
|
|
57
96
|
|
|
58
97
|
Parameters
|
|
59
98
|
----------
|
|
60
|
-
patch :
|
|
99
|
+
patch : NDArray
|
|
61
100
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
62
|
-
target :
|
|
101
|
+
target : NDArray, optional
|
|
63
102
|
Target for the patch, by default None.
|
|
64
103
|
|
|
65
104
|
Returns
|
|
66
105
|
-------
|
|
67
|
-
|
|
68
|
-
Transformed patch and target
|
|
106
|
+
tuple of NDArray
|
|
107
|
+
Transformed patch and target, the target can be returned as `None`.
|
|
69
108
|
"""
|
|
70
|
-
|
|
71
|
-
|
|
109
|
+
if len(self.image_means) != patch.shape[0]:
|
|
110
|
+
raise ValueError(
|
|
111
|
+
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
112
|
+
f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
# reshape mean and std and apply the normalization to the patch
|
|
116
|
+
means = _reshape_stats(self.image_means, patch.ndim)
|
|
117
|
+
stds = _reshape_stats(self.image_stds, patch.ndim)
|
|
118
|
+
norm_patch = self._apply(patch, means, stds)
|
|
119
|
+
|
|
120
|
+
# same for the target patch
|
|
121
|
+
if (
|
|
122
|
+
target is not None
|
|
123
|
+
and self.target_means is not None
|
|
124
|
+
and self.target_stds is not None
|
|
125
|
+
):
|
|
126
|
+
target_means = _reshape_stats(self.target_means, target.ndim)
|
|
127
|
+
target_stds = _reshape_stats(self.target_stds, target.ndim)
|
|
128
|
+
norm_target = self._apply(target, target_means, target_stds)
|
|
129
|
+
else:
|
|
130
|
+
norm_target = None
|
|
72
131
|
|
|
73
132
|
return norm_patch, norm_target
|
|
74
133
|
|
|
75
|
-
def _apply(self, patch:
|
|
134
|
+
def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
76
135
|
"""
|
|
77
136
|
Apply the transform to the image.
|
|
78
137
|
|
|
79
138
|
Parameters
|
|
80
139
|
----------
|
|
81
|
-
patch :
|
|
140
|
+
patch : NDArray
|
|
82
141
|
Image patch, 2D or 3D, shape C(Z)YX.
|
|
142
|
+
mean : NDArray
|
|
143
|
+
Mean values.
|
|
144
|
+
std : NDArray
|
|
145
|
+
Standard deviations.
|
|
83
146
|
|
|
84
147
|
Returns
|
|
85
148
|
-------
|
|
86
|
-
|
|
87
|
-
|
|
149
|
+
NDArray
|
|
150
|
+
Normalized image patch.
|
|
88
151
|
"""
|
|
89
|
-
return ((patch -
|
|
152
|
+
return ((patch - mean) / (std + self.eps)).astype(np.float32)
|
|
90
153
|
|
|
91
154
|
|
|
92
155
|
class Denormalize:
|
|
93
156
|
"""
|
|
94
|
-
Denormalize an image
|
|
157
|
+
Denormalize an image.
|
|
95
158
|
|
|
96
159
|
Denormalization is performed expecting a zero mean and unit variance input. This
|
|
97
160
|
transform expects BC(Z)YX dimensions.
|
|
98
161
|
|
|
99
|
-
|
|
162
|
+
Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
100
163
|
division by zero during the normalization step, which is taken into account during
|
|
101
164
|
denormalization.
|
|
102
165
|
|
|
103
166
|
Parameters
|
|
104
167
|
----------
|
|
105
|
-
|
|
106
|
-
Mean value.
|
|
107
|
-
|
|
108
|
-
Standard deviation value.
|
|
168
|
+
image_means : list or tuple of float
|
|
169
|
+
Mean value per channel.
|
|
170
|
+
image_stds : list or tuple of float
|
|
171
|
+
Standard deviation value per channel.
|
|
109
172
|
|
|
110
|
-
Attributes
|
|
111
|
-
----------
|
|
112
|
-
mean : float
|
|
113
|
-
Mean value.
|
|
114
|
-
std : float
|
|
115
|
-
Standard deviation value.
|
|
116
173
|
"""
|
|
117
174
|
|
|
118
175
|
def __init__(
|
|
119
176
|
self,
|
|
120
|
-
|
|
121
|
-
|
|
177
|
+
image_means: list[float],
|
|
178
|
+
image_stds: list[float],
|
|
122
179
|
):
|
|
123
180
|
"""Constructor.
|
|
124
181
|
|
|
125
182
|
Parameters
|
|
126
183
|
----------
|
|
127
|
-
|
|
128
|
-
Mean.
|
|
129
|
-
|
|
130
|
-
Standard deviation.
|
|
184
|
+
image_means : list of float
|
|
185
|
+
Mean value per channel.
|
|
186
|
+
image_stds : list of float
|
|
187
|
+
Standard deviation value per channel.
|
|
131
188
|
"""
|
|
132
|
-
self.
|
|
133
|
-
self.
|
|
189
|
+
self.image_means = image_means
|
|
190
|
+
self.image_stds = image_stds
|
|
191
|
+
|
|
134
192
|
self.eps = 1e-6
|
|
135
193
|
|
|
136
|
-
def __call__(
|
|
137
|
-
|
|
138
|
-
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
139
|
-
"""Apply the transform to the source patch and the target (optional).
|
|
194
|
+
def __call__(self, patch: NDArray) -> NDArray:
|
|
195
|
+
"""Reverse the normalization operation for a batch of patches.
|
|
140
196
|
|
|
141
197
|
Parameters
|
|
142
198
|
----------
|
|
143
|
-
patch :
|
|
144
|
-
Patch, 2D or 3D, shape
|
|
145
|
-
target : Optional[np.ndarray], optional
|
|
146
|
-
Target for the patch, by default None.
|
|
199
|
+
patch : NDArray
|
|
200
|
+
Patch, 2D or 3D, shape BC(Z)YX.
|
|
147
201
|
|
|
148
202
|
Returns
|
|
149
203
|
-------
|
|
150
|
-
|
|
151
|
-
Transformed
|
|
204
|
+
NDArray
|
|
205
|
+
Transformed array.
|
|
152
206
|
"""
|
|
153
|
-
|
|
154
|
-
|
|
207
|
+
if len(self.image_means) != patch.shape[1]:
|
|
208
|
+
raise ValueError(
|
|
209
|
+
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
210
|
+
f"number of channels (got shape {patch.shape} for BC(Z)YX) do not "
|
|
211
|
+
f"match."
|
|
212
|
+
)
|
|
155
213
|
|
|
156
|
-
|
|
214
|
+
means = _reshape_stats(self.image_means, patch.ndim)
|
|
215
|
+
stds = _reshape_stats(self.image_stds, patch.ndim)
|
|
216
|
+
|
|
217
|
+
denorm_array = self._apply(
|
|
218
|
+
patch,
|
|
219
|
+
np.swapaxes(means, 0, 1), # swap axes as C channel is axis 1
|
|
220
|
+
np.swapaxes(stds, 0, 1),
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
return denorm_array.astype(np.float32)
|
|
157
224
|
|
|
158
|
-
def _apply(self,
|
|
225
|
+
def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
159
226
|
"""
|
|
160
227
|
Apply the transform to the image.
|
|
161
228
|
|
|
162
229
|
Parameters
|
|
163
230
|
----------
|
|
164
|
-
|
|
231
|
+
array : NDArray
|
|
165
232
|
Image patch, 2D or 3D, shape BC(Z)YX.
|
|
233
|
+
mean : NDArray
|
|
234
|
+
Mean values.
|
|
235
|
+
std : NDArray
|
|
236
|
+
Standard deviations.
|
|
166
237
|
|
|
167
238
|
Returns
|
|
168
239
|
-------
|
|
169
|
-
|
|
170
|
-
Denormalized image
|
|
240
|
+
NDArray
|
|
241
|
+
Denormalized image array.
|
|
171
242
|
"""
|
|
172
|
-
return
|
|
243
|
+
return array * (std + self.eps) + mean
|