careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""A class chaining transforms together."""
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List, Optional, Tuple, cast
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.config.data_model import TRANSFORMS_UNION
|
|
8
|
+
|
|
9
|
+
from .n2v_manipulate import N2VManipulate
|
|
10
|
+
from .normalize import Normalize
|
|
11
|
+
from .xy_flip import XYFlip
|
|
12
|
+
from .xy_random_rotate90 import XYRandomRotate90
|
|
13
|
+
|
|
14
|
+
# Registry of the transforms CAREamics can instantiate from a configuration:
# each key is the transform name used in the configuration models and each
# value is the class implementing that transform.
ALL_TRANSFORMS = dict(
    Normalize=Normalize,
    N2VManipulate=N2VManipulate,
    XYFlip=XYFlip,
    XYRandomRotate90=XYRandomRotate90,
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_all_transforms() -> Dict[str, type]:
    """Expose the registry of transforms that CAREamics knows how to build.

    Returns
    -------
    dict
        Mapping from transform name to transform class. The module-level
        ``ALL_TRANSFORMS`` dictionary is returned as-is (not a copy), so any
        change a caller makes to it is visible to every later lookup.
    """
    registry: Dict[str, type] = ALL_TRANSFORMS
    return registry
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Compose:
    """Apply a sequence of CAREamics transforms one after the other.

    Parameters
    ----------
    transform_list : List[TRANSFORMS_UNION]
        Transform configuration models. Each model's ``name`` selects the
        transform class and its ``model_dump()`` provides the constructor
        arguments.

    Attributes
    ----------
    transforms : list
        The instantiated transforms, in the same order as ``transform_list``.
    """

    def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
        """Instantiate every transform described in `transform_list`.

        Parameters
        ----------
        transform_list : List[TRANSFORMS_UNION]
            Transform configuration models. Each model's ``name`` selects the
            transform class and its ``model_dump()`` provides the constructor
            arguments.

        Raises
        ------
        KeyError
            If a model's ``name`` is not a key of ``get_all_transforms()``.
        """
        registry = get_all_transforms()

        # build into a local list so `self.transforms` is only set once every
        # transform has been created successfully
        instantiated = []
        for transform_model in transform_list:
            transform_class = registry[transform_model.name]
            instantiated.append(transform_class(**transform_model.model_dump()))
        self.transforms = instantiated

    def _chain_transforms(
        self, patch: np.ndarray, target: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Feed the data through every transform in turn.

        Starting from ``(patch, target)``, each transform is called with the
        unpacked output of the previous one.

        Parameters
        ----------
        patch : np.ndarray
            Input data.
        target : Optional[np.ndarray]
            Target data, may be None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            Output of the last transform, or ``(patch, target)`` unchanged when
            there are no transforms. The tuple length is whatever the last
            transform returns (``N2VManipulate``, for instance, returns three
            arrays).
        """
        outputs = (patch, target)
        for transform in self.transforms:
            outputs = transform(*outputs)
        return outputs

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, ...]:
        """Run the whole transform chain on the input data.

        Parameters
        ----------
        patch : np.ndarray
            The input data.
        target : Optional[np.ndarray], optional
            Target data, by default None.

        Returns
        -------
        Tuple[np.ndarray, ...]
            The output of the last transform in the chain.
        """
        chained = self._chain_transforms(patch, target)
        return cast(Tuple[np.ndarray, ...], chained)
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""N2V manipulation transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
8
|
+
from careamics.transforms.transform import Transform
|
|
9
|
+
|
|
10
|
+
from .pixel_manipulation import median_manipulate, uniform_manipulate
|
|
11
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class N2VManipulate(Transform):
    """
    Default Noise2Void masking augmentation.

    Every channel of a C(Z)YX patch is masked independently, using either the
    uniform or the median replacement strategy.

    Parameters
    ----------
    roi_size : int, optional
        Size of the replacement area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels to mask, by default 0.2.
    strategy : Literal["uniform", "median"], optional
        Replacement strategy, by default "uniform".
    remove_center : bool, optional
        Whether to remove the central pixel from the patch, by default True.
        Only forwarded to the uniform strategy.
    struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
        StructN2V mask axis, by default "none" (no structured mask).
    struct_mask_span : int, optional
        StructN2V mask span, by default 5. Unused when `struct_mask_axis` is
        "none".
    seed : Optional[int], optional
        Seed of the random number generator, by default None.

    Attributes
    ----------
    masked_pixel_percentage : float
        Percentage of pixels to mask.
    roi_size : int
        Size of the replacement area.
    strategy : Literal["uniform", "median"]
        Replacement strategy.
    remove_center : bool
        Whether to remove the central pixel (uniform strategy only).
    struct_mask : Optional[StructMaskParameters]
        StructN2V mask parameters, None when no structured mask is used.
    rng : Generator
        NumPy random number generator, reused across calls.
    """

    def __init__(
        self,
        roi_size: int = 11,
        masked_pixel_percentage: float = 0.2,
        strategy: Literal[
            "uniform", "median"
        ] = SupportedPixelManipulation.UNIFORM.value,
        remove_center: bool = True,
        struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
        struct_mask_span: int = 5,
        seed: Optional[int] = None,
    ):
        """Store the masking settings and create the random generator.

        Parameters
        ----------
        roi_size : int, optional
            Size of the replacement area, by default 11.
        masked_pixel_percentage : float, optional
            Percentage of pixels to mask, by default 0.2.
        strategy : Literal["uniform", "median"], optional
            Replacement strategy, by default "uniform".
        remove_center : bool, optional
            Whether to remove the central pixel from the patch, by default True.
        struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
            StructN2V mask axis, by default "none".
        struct_mask_span : int, optional
            StructN2V mask span, by default 5.
        seed : Optional[int], optional
            Random seed, by default None.
        """
        # masking hyper-parameters
        self.masked_pixel_percentage = masked_pixel_percentage
        self.roi_size = roi_size
        self.strategy = strategy
        # only consumed by the uniform strategy in `__call__`
        self.remove_center = remove_center

        # StructN2V: "horizontal" maps to axis 0, any other axis value to axis 1
        if struct_mask_axis == SupportedStructAxis.NONE:
            self.struct_mask: Optional[StructMaskParameters] = None
        else:
            mask_axis = 0 if struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1
            self.struct_mask = StructMaskParameters(
                axis=mask_axis,
                span=struct_mask_span,
            )

        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, *args: Any, **kwargs: Any
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Mask every channel of `patch` with the configured strategy.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        *args : Any
            Additional arguments, unused.
        **kwargs : Any
            Additional keyword arguments, unused.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            Masked patch, the unmodified input patch, and the mask. The input
            patch is returned as-is; presumably it serves as the N2V target
            downstream (NOTE(review): confirm with the training loop).

        Raises
        ------
        ValueError
            If `strategy` is neither "uniform" nor "median".
        """
        # pick the per-channel manipulation function; only the uniform variant
        # accepts `remove_center`
        if self.strategy == SupportedPixelManipulation.UNIFORM:
            manipulate = uniform_manipulate
            extra_kwargs = {"remove_center": self.remove_center}
        elif self.strategy == SupportedPixelManipulation.MEDIAN:
            manipulate = median_manipulate
            extra_kwargs = {}
        else:
            raise ValueError(f"Unknown masking strategy ({self.strategy}).")

        masked = np.zeros_like(patch)
        mask = np.zeros_like(patch)

        # channels are manipulated one at a time, sharing the same generator
        for channel in range(patch.shape[0]):
            masked[channel, ...], mask[channel, ...] = manipulate(
                patch=patch[channel, ...],
                mask_pixel_percentage=self.masked_pixel_percentage,
                subpatch_size=self.roi_size,
                struct_params=self.struct_mask,
                rng=self.rng,
                **extra_kwargs,
            )

        return masked, patch, mask
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
"""Normalization and denormalization transforms for image patches."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.transforms.transform import Transform
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
|
|
12
|
+
"""Reshape stats to match the number of dimensions of the input image.
|
|
13
|
+
|
|
14
|
+
This allows to broadcast the stats (mean or std) to the image dimensions, and
|
|
15
|
+
thus directly perform a vectorial calculation.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
stats : list of float
|
|
20
|
+
List of stats, mean or standard deviation.
|
|
21
|
+
ndim : int
|
|
22
|
+
Number of dimensions of the image, including the C channel.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
NDArray
|
|
27
|
+
Reshaped stats.
|
|
28
|
+
"""
|
|
29
|
+
return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Normalize(Transform):
    """
    Normalize an image or image patch.

    Normalization is a zero mean and unit variance. This transform expects C(Z)YX
    dimensions.

    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
    division by zero and that it returns a float32 image.

    Parameters
    ----------
    image_means : list of float
        Mean value per channel.
    image_stds : list of float
        Standard deviation value per channel.
    target_means : list of float, optional
        Target mean value per channel, by default None.
    target_stds : list of float, optional
        Target standard deviation value per channel, by default None.

    Attributes
    ----------
    image_means : list of float
        Mean value per channel.
    image_stds : list of float
        Standard deviation value per channel.
    target_means : list of float, optional
        Target mean value per channel, by default None.
    target_stds : list of float, optional
        Target standard deviation value per channel, by default None.
    eps : float
        Value added to the standard deviations (1e-6).
    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
        target_means: Optional[list[float]] = None,
        target_stds: Optional[list[float]] = None,
    ):
        """Constructor.

        Parameters
        ----------
        image_means : list of float
            Mean value per channel.
        image_stds : list of float
            Standard deviation value per channel.
        target_means : list of float, optional
            Target mean value per channel, by default None.
        target_stds : list of float, optional
            Target standard deviation value per channel, by default None.
        """
        self.image_means = image_means
        self.image_stds = image_stds
        self.target_means = target_means
        self.target_stds = target_stds

        # guards against division by zero for constant channels
        self.eps = 1e-6

    def __call__(
        self, patch: np.ndarray, target: Optional[NDArray] = None
    ) -> tuple[NDArray, Optional[NDArray]]:
        """Apply the transform to the source patch and the target (optional).

        Parameters
        ----------
        patch : NDArray
            Patch, 2D or 3D, shape C(Z)YX.
        target : NDArray, optional
            Target for the patch, by default None.

        Returns
        -------
        tuple of NDArray
            Transformed patch and target. The target is returned as `None` when
            no target is given, and also when a target is given but
            `target_means` or `target_stds` is None (NOTE(review): the target is
            dropped rather than passed through in that case — confirm this is
            intended).

        Raises
        ------
        ValueError
            If the number of image means or stds differs from the number of
            channels of `patch`, or, when the target is normalized, if the number
            of target means or stds differs from the number of channels of
            `target`.
        """
        # One statistic per channel is required. Without this check a mismatched
        # list either fails with an obscure broadcasting error or is silently
        # broadcast (e.g. a single std applied to every channel, or a 1-channel
        # array expanded to as many channels as there are stats).
        for label, stats in (("means", self.image_means), ("stds", self.image_stds)):
            if len(stats) != patch.shape[0]:
                raise ValueError(
                    f"Number of {label} (got a list of size {len(stats)}) and "
                    f"number of channels (got shape {patch.shape} for C(Z)YX) do "
                    f"not match."
                )

        # reshape mean and std and apply the normalization to the patch
        means = _reshape_stats(self.image_means, patch.ndim)
        stds = _reshape_stats(self.image_stds, patch.ndim)
        norm_patch = self._apply(patch, means, stds)

        # same for the target patch
        if (
            target is not None
            and self.target_means is not None
            and self.target_stds is not None
        ):
            for label, stats in (
                ("target means", self.target_means),
                ("target stds", self.target_stds),
            ):
                if len(stats) != target.shape[0]:
                    raise ValueError(
                        f"Number of {label} (got a list of size {len(stats)}) and "
                        f"number of target channels (got shape {target.shape} for "
                        f"C(Z)YX) do not match."
                    )

            target_means = _reshape_stats(self.target_means, target.ndim)
            target_stds = _reshape_stats(self.target_stds, target.ndim)
            norm_target = self._apply(target, target_means, target_stds)
        else:
            norm_target = None

        return norm_patch, norm_target

    def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
        """
        Apply the transform to the image.

        Parameters
        ----------
        patch : NDArray
            Image patch, 2D or 3D, shape C(Z)YX.
        mean : NDArray
            Mean values, broadcastable against `patch`.
        std : NDArray
            Standard deviations, broadcastable against `patch`.

        Returns
        -------
        NDArray
            Normalized image patch, cast to float32.
        """
        return ((patch - mean) / (std + self.eps)).astype(np.float32)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class Denormalize:
    """
    Denormalize an image.

    Denormalization is performed expecting a zero mean and unit variance input. This
    transform expects BC(Z)YX dimensions (a batch of patches).

    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
    division by zero during the normalization step, which is taken into account during
    denormalization.

    Parameters
    ----------
    image_means : list or tuple of float
        Mean value per channel.
    image_stds : list or tuple of float
        Standard deviation value per channel.
    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
    ):
        """Constructor.

        Parameters
        ----------
        image_means : list of float
            Mean value per channel.
        image_stds : list of float
            Standard deviation value per channel.
        """
        self.image_means = image_means
        self.image_stds = image_stds

        # same epsilon as the normalization step, so the inversion is exact
        self.eps = 1e-6

    def __call__(self, patch: NDArray) -> NDArray:
        """Reverse the normalization operation for a batch of patches.

        Parameters
        ----------
        patch : NDArray
            Patch, 2D or 3D, shape BC(Z)YX.

        Returns
        -------
        NDArray
            Transformed array, cast to float32.

        Raises
        ------
        ValueError
            If the number of means or stds differs from the number of channels
            (axis 1) of `patch`.
        """
        # One statistic per channel is required. Without this check a mismatched
        # std list is silently broadcast (e.g. a single std applied to every
        # channel) or fails with an obscure broadcasting error.
        for label, stats in (("means", self.image_means), ("stds", self.image_stds)):
            if len(stats) != patch.shape[1]:
                raise ValueError(
                    f"Number of {label} (got a list of size {len(stats)}) and "
                    f"number of channels (got shape {patch.shape} for BC(Z)YX) do "
                    f"not match."
                )

        means = _reshape_stats(self.image_means, patch.ndim)
        stds = _reshape_stats(self.image_stds, patch.ndim)

        denorm_array = self._apply(
            patch,
            np.swapaxes(means, 0, 1),  # swap axes as C channel is axis 1
            np.swapaxes(stds, 0, 1),
        )

        return denorm_array.astype(np.float32)

    def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
        """
        Apply the inverse transform to the array.

        Parameters
        ----------
        array : NDArray
            Normalized array, broadcastable against `mean` and `std`.
        mean : NDArray
            Mean values.
        std : NDArray
            Standard deviations.

        Returns
        -------
        NDArray
            Denormalized image array.
        """
        return array * (std + self.eps) + mean
|