careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +92 -55
- careamics/config/__init__.py +0 -1
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +4 -15
- careamics/config/configuration_example.py +4 -4
- careamics/config/configuration_factory.py +113 -55
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +63 -165
- careamics/config/inference_model.py +9 -75
- careamics/config/optimizer_models.py +4 -4
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +13 -7
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +13 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +5 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +84 -76
- careamics/dataset/iterable_dataset.py +166 -134
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +56 -14
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +20 -14
- careamics/dataset/patching/tiled_patching.py +13 -7
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +63 -41
- careamics/lightning_module.py +9 -3
- careamics/lightning_prediction_datamodule.py +15 -20
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +1 -3
- careamics/losses/loss_factory.py +2 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +4 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +122 -25
- careamics/models/model_factory.py +2 -1
- careamics/models/unet.py +114 -19
- careamics/prediction/stitch_prediction.py +2 -5
- careamics/transforms/__init__.py +4 -25
- careamics/transforms/compose.py +124 -0
- careamics/transforms/n2v_manipulate.py +65 -34
- careamics/transforms/normalize.py +91 -28
- careamics/transforms/pixel_manipulation.py +7 -7
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +66 -60
- careamics/utils/__init__.py +0 -1
- careamics/utils/base_enum.py +28 -0
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -24
- careamics/config/transformations/nd_flip_model.py +0 -32
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -93
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""A class chaining transforms together."""
|
|
2
|
+
|
|
3
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.config.data_model import TRANSFORMS_UNION
|
|
8
|
+
|
|
9
|
+
from .n2v_manipulate import N2VManipulate
|
|
10
|
+
from .normalize import Normalize
|
|
11
|
+
from .transform import Transform
|
|
12
|
+
from .xy_flip import XYFlip
|
|
13
|
+
from .xy_random_rotate90 import XYRandomRotate90
|
|
14
|
+
|
|
15
|
+
ALL_TRANSFORMS = {
|
|
16
|
+
"Normalize": Normalize,
|
|
17
|
+
"N2VManipulate": N2VManipulate,
|
|
18
|
+
"XYFlip": XYFlip,
|
|
19
|
+
"XYRandomRotate90": XYRandomRotate90,
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_all_transforms() -> Dict[str, type]:
|
|
24
|
+
"""Return all the transforms accepted by CAREamics.
|
|
25
|
+
|
|
26
|
+
Returns
|
|
27
|
+
-------
|
|
28
|
+
dict
|
|
29
|
+
A dictionary with all the transforms accepted by CAREamics, where the keys are
|
|
30
|
+
the transform names and the values are the transform classes.
|
|
31
|
+
"""
|
|
32
|
+
return ALL_TRANSFORMS
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Compose:
|
|
36
|
+
"""A class chaining transforms together.
|
|
37
|
+
|
|
38
|
+
Parameters
|
|
39
|
+
----------
|
|
40
|
+
transform_list : List[TRANSFORMS_UNION]
|
|
41
|
+
A list of transform configuration models, each containing the name of a
|
|
42
|
+
transform and its parameters.
|
|
43
|
+
|
|
44
|
+
Attributes
|
|
45
|
+
----------
|
|
46
|
+
_callable_transforms : Callable
|
|
47
|
+
A callable that applies the transforms to the input data.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
|
|
51
|
+
"""Instantiate a Compose object.
|
|
52
|
+
|
|
53
|
+
Parameters
|
|
54
|
+
----------
|
|
55
|
+
transform_list : List[TRANSFORMS_UNION]
|
|
56
|
+
A list of transform configuration models, each containing the name of a
|
|
57
|
+
transform and its parameters.
|
|
58
|
+
"""
|
|
59
|
+
# retrieve all available transforms
|
|
60
|
+
all_transforms = get_all_transforms()
|
|
61
|
+
|
|
62
|
+
# instantiate all transforms
|
|
63
|
+
transforms = [all_transforms[t.name](**t.model_dump()) for t in transform_list]
|
|
64
|
+
|
|
65
|
+
self._callable_transforms = self._chain_transforms(transforms)
|
|
66
|
+
|
|
67
|
+
def _chain_transforms(self, transforms: List[Transform]) -> Callable:
|
|
68
|
+
"""Chain the transforms together.
|
|
69
|
+
|
|
70
|
+
Parameters
|
|
71
|
+
----------
|
|
72
|
+
transforms : List[Transform]
|
|
73
|
+
A list of transforms to chain together.
|
|
74
|
+
|
|
75
|
+
Returns
|
|
76
|
+
-------
|
|
77
|
+
Callable
|
|
78
|
+
A callable that applies the transforms in order to the input data.
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
def _chain(
|
|
82
|
+
patch: np.ndarray, target: Optional[np.ndarray]
|
|
83
|
+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
84
|
+
"""Chain transforms on the input data.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
patch : np.ndarray
|
|
89
|
+
Input data.
|
|
90
|
+
target : Optional[np.ndarray]
|
|
91
|
+
Target data, by default None.
|
|
92
|
+
|
|
93
|
+
Returns
|
|
94
|
+
-------
|
|
95
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
96
|
+
The output of the transformations.
|
|
97
|
+
"""
|
|
98
|
+
params = (patch, target)
|
|
99
|
+
|
|
100
|
+
for t in transforms:
|
|
101
|
+
params = t(*params)
|
|
102
|
+
|
|
103
|
+
return params
|
|
104
|
+
|
|
105
|
+
return _chain
|
|
106
|
+
|
|
107
|
+
def __call__(
|
|
108
|
+
self, patch: np.ndarray, target: Optional[np.ndarray] = None
|
|
109
|
+
) -> Tuple[np.ndarray, ...]:
|
|
110
|
+
"""Apply the transforms to the input data.
|
|
111
|
+
|
|
112
|
+
Parameters
|
|
113
|
+
----------
|
|
114
|
+
patch : np.ndarray
|
|
115
|
+
The input data.
|
|
116
|
+
target : Optional[np.ndarray], optional
|
|
117
|
+
Target data, by default None.
|
|
118
|
+
|
|
119
|
+
Returns
|
|
120
|
+
-------
|
|
121
|
+
Tuple[np.ndarray, ...]
|
|
122
|
+
The output of the transformations.
|
|
123
|
+
"""
|
|
124
|
+
return self._callable_transforms(patch, target)
|
|
@@ -1,26 +1,53 @@
|
|
|
1
|
+
"""N2V manipulation transform."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Literal, Optional, Tuple
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
4
|
-
from albumentations import ImageOnlyTransform
|
|
5
6
|
|
|
6
7
|
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
8
|
+
from careamics.transforms.transform import Transform
|
|
7
9
|
|
|
8
10
|
from .pixel_manipulation import median_manipulate, uniform_manipulate
|
|
9
11
|
from .struct_mask_parameters import StructMaskParameters
|
|
10
12
|
|
|
11
13
|
|
|
12
|
-
class N2VManipulate(
|
|
14
|
+
class N2VManipulate(Transform):
|
|
13
15
|
"""
|
|
14
16
|
Default augmentation for the N2V model.
|
|
15
17
|
|
|
16
|
-
This transform expects (Z)
|
|
18
|
+
This transform expects C(Z)YX dimensions.
|
|
17
19
|
|
|
18
20
|
Parameters
|
|
19
21
|
----------
|
|
20
|
-
|
|
21
|
-
|
|
22
|
+
roi_size : int, optional
|
|
23
|
+
Size of the replacement area, by default 11.
|
|
24
|
+
masked_pixel_percentage : float, optional
|
|
25
|
+
Percentage of pixels to mask, by default 0.2.
|
|
26
|
+
strategy : Literal[ "uniform", "median" ], optional
|
|
27
|
+
Replacement strategy, uniform or median, by default uniform.
|
|
28
|
+
remove_center : bool, optional
|
|
29
|
+
Whether to remove central pixel from patch, by default True.
|
|
30
|
+
struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
|
|
31
|
+
StructN2V mask axis, by default "none".
|
|
32
|
+
struct_mask_span : int, optional
|
|
33
|
+
StructN2V mask span, by default 5.
|
|
34
|
+
seed : Optional[int], optional
|
|
35
|
+
Random seed, by default None.
|
|
36
|
+
|
|
37
|
+
Attributes
|
|
38
|
+
----------
|
|
39
|
+
masked_pixel_percentage : float
|
|
40
|
+
Percentage of pixels to mask.
|
|
22
41
|
roi_size : int
|
|
23
|
-
Size of the
|
|
42
|
+
Size of the replacement area.
|
|
43
|
+
strategy : Literal[ "uniform", "median" ]
|
|
44
|
+
Replacement strategy, uniform or median.
|
|
45
|
+
remove_center : bool
|
|
46
|
+
Whether to remove central pixel from patch.
|
|
47
|
+
struct_mask : Optional[StructMaskParameters]
|
|
48
|
+
StructN2V mask parameters.
|
|
49
|
+
rng : Generator
|
|
50
|
+
Random number generator.
|
|
24
51
|
"""
|
|
25
52
|
|
|
26
53
|
def __init__(
|
|
@@ -33,29 +60,31 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
33
60
|
remove_center: bool = True,
|
|
34
61
|
struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
35
62
|
struct_mask_span: int = 5,
|
|
63
|
+
seed: Optional[int] = None, # TODO use in pixel manipulation
|
|
36
64
|
):
|
|
37
65
|
"""Constructor.
|
|
38
66
|
|
|
39
67
|
Parameters
|
|
40
68
|
----------
|
|
41
69
|
roi_size : int, optional
|
|
42
|
-
Size of the replacement area, by default 11
|
|
70
|
+
Size of the replacement area, by default 11.
|
|
43
71
|
masked_pixel_percentage : float, optional
|
|
44
|
-
Percentage of pixels to mask, by default 0.2
|
|
72
|
+
Percentage of pixels to mask, by default 0.2.
|
|
45
73
|
strategy : Literal[ "uniform", "median" ], optional
|
|
46
|
-
Replaccement strategy, uniform or median, by default uniform
|
|
74
|
+
Replacement strategy, uniform or median, by default uniform.
|
|
47
75
|
remove_center : bool, optional
|
|
48
|
-
Whether to remove central pixel from patch, by default True
|
|
76
|
+
Whether to remove central pixel from patch, by default True.
|
|
49
77
|
struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
|
|
50
|
-
StructN2V mask axis, by default "none"
|
|
78
|
+
StructN2V mask axis, by default "none".
|
|
51
79
|
struct_mask_span : int, optional
|
|
52
|
-
StructN2V mask span, by default 5
|
|
80
|
+
StructN2V mask span, by default 5.
|
|
81
|
+
seed : Optional[int], optional
|
|
82
|
+
Random seed, by default None.
|
|
53
83
|
"""
|
|
54
|
-
super().__init__(p=1)
|
|
55
84
|
self.masked_pixel_percentage = masked_pixel_percentage
|
|
56
85
|
self.roi_size = roi_size
|
|
57
86
|
self.strategy = strategy
|
|
58
|
-
self.remove_center = remove_center
|
|
87
|
+
self.remove_center = remove_center # TODO is this ever used?
|
|
59
88
|
|
|
60
89
|
if struct_mask_axis == SupportedStructAxis.NONE:
|
|
61
90
|
self.struct_mask: Optional[StructMaskParameters] = None
|
|
@@ -65,23 +94,35 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
65
94
|
span=struct_mask_span,
|
|
66
95
|
)
|
|
67
96
|
|
|
68
|
-
|
|
69
|
-
self
|
|
97
|
+
# numpy random generator
|
|
98
|
+
self.rng = np.random.default_rng(seed=seed)
|
|
99
|
+
|
|
100
|
+
def __call__(
|
|
101
|
+
self, patch: np.ndarray, *args: Any, **kwargs: Any
|
|
70
102
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
71
103
|
"""Apply the transform to the image.
|
|
72
104
|
|
|
73
105
|
Parameters
|
|
74
106
|
----------
|
|
75
|
-
|
|
76
|
-
Image
|
|
107
|
+
patch : np.ndarray
|
|
108
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
109
|
+
*args : Any
|
|
110
|
+
Additional arguments, unused.
|
|
111
|
+
**kwargs : Any
|
|
112
|
+
Additional keyword arguments, unused.
|
|
113
|
+
|
|
114
|
+
Returns
|
|
115
|
+
-------
|
|
116
|
+
Tuple[np.ndarray, np.ndarray, np.ndarray]
|
|
117
|
+
Masked patch, original patch, and mask.
|
|
77
118
|
"""
|
|
78
119
|
masked = np.zeros_like(patch)
|
|
79
120
|
mask = np.zeros_like(patch)
|
|
80
121
|
if self.strategy == SupportedPixelManipulation.UNIFORM:
|
|
81
122
|
# Iterate over the channels to apply manipulation separately
|
|
82
|
-
for c in range(patch.shape[
|
|
83
|
-
masked[
|
|
84
|
-
patch=patch[
|
|
123
|
+
for c in range(patch.shape[0]):
|
|
124
|
+
masked[c, ...], mask[c, ...] = uniform_manipulate(
|
|
125
|
+
patch=patch[c, ...],
|
|
85
126
|
mask_pixel_percentage=self.masked_pixel_percentage,
|
|
86
127
|
subpatch_size=self.roi_size,
|
|
87
128
|
remove_center=self.remove_center,
|
|
@@ -89,9 +130,9 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
89
130
|
)
|
|
90
131
|
elif self.strategy == SupportedPixelManipulation.MEDIAN:
|
|
91
132
|
# Iterate over the channels to apply manipulation separately
|
|
92
|
-
for c in range(patch.shape[
|
|
93
|
-
masked[
|
|
94
|
-
patch=patch[
|
|
133
|
+
for c in range(patch.shape[0]):
|
|
134
|
+
masked[c, ...], mask[c, ...] = median_manipulate(
|
|
135
|
+
patch=patch[c, ...],
|
|
95
136
|
mask_pixel_percentage=self.masked_pixel_percentage,
|
|
96
137
|
subpatch_size=self.roi_size,
|
|
97
138
|
struct_params=self.struct_mask,
|
|
@@ -101,13 +142,3 @@ class N2VManipulate(ImageOnlyTransform):
|
|
|
101
142
|
|
|
102
143
|
# TODO why return patch?
|
|
103
144
|
return masked, patch, mask
|
|
104
|
-
|
|
105
|
-
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
|
106
|
-
"""Get the transform parameters.
|
|
107
|
-
|
|
108
|
-
Returns
|
|
109
|
-
-------
|
|
110
|
-
Tuple[str, ...]
|
|
111
|
-
Transform parameters.
|
|
112
|
-
"""
|
|
113
|
-
return ("roi_size", "masked_pixel_percentage", "strategy", "struct_mask")
|
|
@@ -1,27 +1,35 @@
|
|
|
1
|
-
|
|
1
|
+
"""Normalization and denormalization transforms for image patches."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
4
|
-
|
|
6
|
+
|
|
7
|
+
from careamics.transforms.transform import Transform
|
|
5
8
|
|
|
6
9
|
|
|
7
|
-
class Normalize(
|
|
10
|
+
class Normalize(Transform):
|
|
8
11
|
"""
|
|
9
12
|
Normalize an image or image patch.
|
|
10
13
|
|
|
11
|
-
Normalization is a zero mean and unit variance. This transform expects (Z)
|
|
14
|
+
Normalization is a zero mean and unit variance. This transform expects C(Z)YX
|
|
12
15
|
dimensions.
|
|
13
16
|
|
|
14
17
|
Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
15
18
|
division by zero and that it returns a float32 image.
|
|
16
19
|
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
mean : float
|
|
23
|
+
Mean value.
|
|
24
|
+
std : float
|
|
25
|
+
Standard deviation value.
|
|
26
|
+
|
|
17
27
|
Attributes
|
|
18
28
|
----------
|
|
19
29
|
mean : float
|
|
20
30
|
Mean value.
|
|
21
31
|
std : float
|
|
22
32
|
Standard deviation value.
|
|
23
|
-
eps : float
|
|
24
|
-
Epsilon value to avoid division by zero.
|
|
25
33
|
"""
|
|
26
34
|
|
|
27
35
|
def __init__(
|
|
@@ -29,61 +37,82 @@ class Normalize(DualTransform):
|
|
|
29
37
|
mean: float,
|
|
30
38
|
std: float,
|
|
31
39
|
):
|
|
32
|
-
|
|
40
|
+
"""Constructor.
|
|
33
41
|
|
|
42
|
+
Parameters
|
|
43
|
+
----------
|
|
44
|
+
mean : float
|
|
45
|
+
Mean value.
|
|
46
|
+
std : float
|
|
47
|
+
Standard deviation value.
|
|
48
|
+
"""
|
|
34
49
|
self.mean = mean
|
|
35
50
|
self.std = std
|
|
36
51
|
self.eps = 1e-6
|
|
37
52
|
|
|
38
|
-
def
|
|
39
|
-
|
|
40
|
-
|
|
53
|
+
def __call__(
|
|
54
|
+
self, patch: np.ndarray, target: Optional[np.ndarray] = None
|
|
55
|
+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
56
|
+
"""Apply the transform to the source patch and the target (optional).
|
|
41
57
|
|
|
42
58
|
Parameters
|
|
43
59
|
----------
|
|
44
60
|
patch : np.ndarray
|
|
45
|
-
|
|
61
|
+
Patch, 2D or 3D, shape C(Z)YX.
|
|
62
|
+
target : Optional[np.ndarray], optional
|
|
63
|
+
Target for the patch, by default None.
|
|
46
64
|
|
|
47
65
|
Returns
|
|
48
66
|
-------
|
|
49
|
-
np.ndarray
|
|
50
|
-
|
|
67
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
68
|
+
Transformed patch and target.
|
|
51
69
|
"""
|
|
52
|
-
|
|
70
|
+
norm_patch = self._apply(patch)
|
|
71
|
+
norm_target = self._apply(target) if target is not None else None
|
|
53
72
|
|
|
54
|
-
|
|
55
|
-
"""
|
|
56
|
-
Apply the transform to the mask.
|
|
73
|
+
return norm_patch, norm_target
|
|
57
74
|
|
|
58
|
-
|
|
75
|
+
def _apply(self, patch: np.ndarray) -> np.ndarray:
|
|
76
|
+
"""
|
|
77
|
+
Apply the transform to the image.
|
|
59
78
|
|
|
60
79
|
Parameters
|
|
61
80
|
----------
|
|
62
|
-
|
|
63
|
-
|
|
81
|
+
patch : np.ndarray
|
|
82
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
np.ndarray
|
|
87
|
+
Normalized image patch.
|
|
64
88
|
"""
|
|
65
|
-
return
|
|
89
|
+
return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
|
|
66
90
|
|
|
67
91
|
|
|
68
|
-
class Denormalize
|
|
92
|
+
class Denormalize:
|
|
69
93
|
"""
|
|
70
94
|
Denormalize an image or image patch.
|
|
71
95
|
|
|
72
96
|
Denormalization is performed expecting a zero mean and unit variance input. This
|
|
73
|
-
transform expects (Z)
|
|
97
|
+
transform expects C(Z)YX dimensions.
|
|
74
98
|
|
|
75
99
|
Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
76
100
|
division by zero during the normalization step, which is taken into account during
|
|
77
101
|
denormalization.
|
|
78
102
|
|
|
103
|
+
Parameters
|
|
104
|
+
----------
|
|
105
|
+
mean : float
|
|
106
|
+
Mean value.
|
|
107
|
+
std : float
|
|
108
|
+
Standard deviation value.
|
|
109
|
+
|
|
79
110
|
Attributes
|
|
80
111
|
----------
|
|
81
112
|
mean : float
|
|
82
113
|
Mean value.
|
|
83
114
|
std : float
|
|
84
115
|
Standard deviation value.
|
|
85
|
-
eps : float
|
|
86
|
-
Epsilon value to avoid division by zero.
|
|
87
116
|
"""
|
|
88
117
|
|
|
89
118
|
def __init__(
|
|
@@ -91,19 +120,53 @@ class Denormalize(DualTransform):
|
|
|
91
120
|
mean: float,
|
|
92
121
|
std: float,
|
|
93
122
|
):
|
|
94
|
-
|
|
123
|
+
"""Constructor.
|
|
95
124
|
|
|
125
|
+
Parameters
|
|
126
|
+
----------
|
|
127
|
+
mean : float
|
|
128
|
+
Mean.
|
|
129
|
+
std : float
|
|
130
|
+
Standard deviation.
|
|
131
|
+
"""
|
|
96
132
|
self.mean = mean
|
|
97
133
|
self.std = std
|
|
98
134
|
self.eps = 1e-6
|
|
99
135
|
|
|
100
|
-
def
|
|
136
|
+
def __call__(
|
|
137
|
+
self, patch: np.ndarray, target: Optional[np.ndarray] = None
|
|
138
|
+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
139
|
+
"""Apply the transform to the source patch and the target (optional).
|
|
140
|
+
|
|
141
|
+
Parameters
|
|
142
|
+
----------
|
|
143
|
+
patch : np.ndarray
|
|
144
|
+
Patch, 2D or 3D, shape C(Z)YX.
|
|
145
|
+
target : Optional[np.ndarray], optional
|
|
146
|
+
Target for the patch, by default None.
|
|
147
|
+
|
|
148
|
+
Returns
|
|
149
|
+
-------
|
|
150
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
151
|
+
Transformed patch and target.
|
|
152
|
+
"""
|
|
153
|
+
norm_patch = self._apply(patch)
|
|
154
|
+
norm_target = self._apply(target) if target is not None else None
|
|
155
|
+
|
|
156
|
+
return norm_patch, norm_target
|
|
157
|
+
|
|
158
|
+
def _apply(self, patch: np.ndarray) -> np.ndarray:
|
|
101
159
|
"""
|
|
102
160
|
Apply the transform to the image.
|
|
103
161
|
|
|
104
162
|
Parameters
|
|
105
163
|
----------
|
|
106
164
|
patch : np.ndarray
|
|
107
|
-
Image
|
|
165
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
166
|
+
|
|
167
|
+
Returns
|
|
168
|
+
-------
|
|
169
|
+
np.ndarray
|
|
170
|
+
Denormalized image patch.
|
|
108
171
|
"""
|
|
109
172
|
return patch * (self.std + self.eps) + self.mean
|
|
@@ -4,7 +4,8 @@ Pixel manipulation methods.
|
|
|
4
4
|
Pixel manipulation is used in N2V and similar algorithm to replace the value of
|
|
5
5
|
masked pixels.
|
|
6
6
|
"""
|
|
7
|
-
|
|
7
|
+
|
|
8
|
+
from typing import Optional, Tuple
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
11
|
|
|
@@ -14,7 +15,7 @@ from .struct_mask_parameters import StructMaskParameters
|
|
|
14
15
|
def _apply_struct_mask(
|
|
15
16
|
patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
|
|
16
17
|
) -> np.ndarray:
|
|
17
|
-
"""
|
|
18
|
+
"""Apply structN2V masks to patch.
|
|
18
19
|
|
|
19
20
|
Each point in `coords` corresponds to the center of a mask; masks are parameterized
|
|
20
21
|
by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by
|
|
@@ -97,7 +98,7 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
|
97
98
|
|
|
98
99
|
|
|
99
100
|
def _get_stratified_coords(
|
|
100
|
-
mask_pixel_perc: float, shape:
|
|
101
|
+
mask_pixel_perc: float, shape: Tuple[int, ...]
|
|
101
102
|
) -> np.ndarray:
|
|
102
103
|
"""
|
|
103
104
|
Generate coordinates of the pixels to mask.
|
|
@@ -246,9 +247,8 @@ def uniform_manipulate(
|
|
|
246
247
|
subpatch_size : int
|
|
247
248
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
248
249
|
remove_center : bool
|
|
249
|
-
Whether to remove the center pixel from the subpatch, by default False.
|
|
250
|
-
|
|
251
|
-
struct_params: Optional[StructMaskParameters]
|
|
250
|
+
Whether to remove the center pixel from the subpatch, by default False.
|
|
251
|
+
struct_params : Optional[StructMaskParameters]
|
|
252
252
|
Parameters for the structN2V mask (axis and span).
|
|
253
253
|
|
|
254
254
|
Returns
|
|
@@ -322,7 +322,7 @@ def median_manipulate(
|
|
|
322
322
|
Approximate percentage of pixels to be masked.
|
|
323
323
|
subpatch_size : int
|
|
324
324
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
325
|
-
struct_params: Optional[StructMaskParameters]
|
|
325
|
+
struct_params : Optional[StructMaskParameters]
|
|
326
326
|
Parameters for the structN2V mask (axis and span).
|
|
327
327
|
|
|
328
328
|
Returns
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Class representing the parameters of structN2V masks."""
|
|
2
|
+
|
|
1
3
|
from dataclasses import dataclass
|
|
2
4
|
from typing import Literal
|
|
3
5
|
|
|
@@ -6,7 +8,7 @@ from typing import Literal
|
|
|
6
8
|
class StructMaskParameters:
|
|
7
9
|
"""Parameters of structN2V masks.
|
|
8
10
|
|
|
9
|
-
|
|
11
|
+
Attributes
|
|
10
12
|
----------
|
|
11
13
|
axis : Literal[0, 1]
|
|
12
14
|
Axis along which to apply the mask, horizontal (0) or vertical (1).
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""A general parent class for transforms."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Transform:
|
|
7
|
+
"""A general parent class for transforms."""
|
|
8
|
+
|
|
9
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
10
|
+
"""Apply the transform.
|
|
11
|
+
|
|
12
|
+
Parameters
|
|
13
|
+
----------
|
|
14
|
+
*args : Any
|
|
15
|
+
Arguments.
|
|
16
|
+
**kwargs : Any
|
|
17
|
+
Keyword arguments.
|
|
18
|
+
|
|
19
|
+
Returns
|
|
20
|
+
-------
|
|
21
|
+
Any
|
|
22
|
+
Transformed data.
|
|
23
|
+
"""
|
|
24
|
+
pass
|
careamics/transforms/tta.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Test-time augmentations."""
|
|
2
|
+
|
|
2
3
|
from typing import List
|
|
3
4
|
|
|
4
|
-
import numpy as np
|
|
5
5
|
from torch import Tensor, flip, mean, rot90, stack
|
|
6
6
|
|
|
7
7
|
|
|
@@ -48,7 +48,7 @@ class ImageRestorationTTA:
|
|
|
48
48
|
augmented_flip.append(flip(x_, dims=(-3, -1)))
|
|
49
49
|
return augmented_flip
|
|
50
50
|
|
|
51
|
-
def backward(self, x: List[Tensor]) ->
|
|
51
|
+
def backward(self, x: List[Tensor]) -> Tensor:
|
|
52
52
|
"""Undo the test-time augmentation.
|
|
53
53
|
|
|
54
54
|
Parameters
|