careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,10 +1,34 @@
|
|
|
1
|
-
|
|
1
|
+
"""Normalization and denormalization transforms for image patches."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
4
7
|
|
|
5
8
|
from careamics.transforms.transform import Transform
|
|
6
9
|
|
|
7
10
|
|
|
11
|
+
def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
|
|
12
|
+
"""Reshape stats to match the number of dimensions of the input image.
|
|
13
|
+
|
|
14
|
+
This allows to broadcast the stats (mean or std) to the image dimensions, and
|
|
15
|
+
thus directly perform a vectorial calculation.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
stats : list of float
|
|
20
|
+
List of stats, mean or standard deviation.
|
|
21
|
+
ndim : int
|
|
22
|
+
Number of dimensions of the image, including the C channel.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
NDArray
|
|
27
|
+
Reshaped stats.
|
|
28
|
+
"""
|
|
29
|
+
return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
|
|
30
|
+
|
|
31
|
+
|
|
8
32
|
class Normalize(Transform):
|
|
9
33
|
"""
|
|
10
34
|
Normalize an image or image patch.
|
|
@@ -15,106 +39,205 @@ class Normalize(Transform):
|
|
|
15
39
|
Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
16
40
|
division by zero and that it returns a float32 image.
|
|
17
41
|
|
|
42
|
+
Parameters
|
|
43
|
+
----------
|
|
44
|
+
image_means : list of float
|
|
45
|
+
Mean value per channel.
|
|
46
|
+
image_stds : list of float
|
|
47
|
+
Standard deviation value per channel.
|
|
48
|
+
target_means : list of float, optional
|
|
49
|
+
Target mean value per channel, by default None.
|
|
50
|
+
target_stds : list of float, optional
|
|
51
|
+
Target standard deviation value per channel, by default None.
|
|
52
|
+
|
|
18
53
|
Attributes
|
|
19
54
|
----------
|
|
20
|
-
|
|
21
|
-
Mean value.
|
|
22
|
-
|
|
23
|
-
Standard deviation value.
|
|
55
|
+
image_means : list of float
|
|
56
|
+
Mean value per channel.
|
|
57
|
+
image_stds : list of float
|
|
58
|
+
Standard deviation value per channel.
|
|
59
|
+
target_means :list of float, optional
|
|
60
|
+
Target mean value per channel, by default None.
|
|
61
|
+
target_stds : list of float, optional
|
|
62
|
+
Target standard deviation value per channel, by default None.
|
|
24
63
|
"""
|
|
25
64
|
|
|
26
65
|
def __init__(
|
|
27
66
|
self,
|
|
28
|
-
|
|
29
|
-
|
|
67
|
+
image_means: list[float],
|
|
68
|
+
image_stds: list[float],
|
|
69
|
+
target_means: Optional[list[float]] = None,
|
|
70
|
+
target_stds: Optional[list[float]] = None,
|
|
30
71
|
):
|
|
31
|
-
|
|
32
|
-
|
|
72
|
+
"""Constructor.
|
|
73
|
+
|
|
74
|
+
Parameters
|
|
75
|
+
----------
|
|
76
|
+
image_means : list of float
|
|
77
|
+
Mean value per channel.
|
|
78
|
+
image_stds : list of float
|
|
79
|
+
Standard deviation value per channel.
|
|
80
|
+
target_means : list of float, optional
|
|
81
|
+
Target mean value per channel, by default None.
|
|
82
|
+
target_stds : list of float, optional
|
|
83
|
+
Target standard deviation value per channel, by default None.
|
|
84
|
+
"""
|
|
85
|
+
self.image_means = image_means
|
|
86
|
+
self.image_stds = image_stds
|
|
87
|
+
self.target_means = target_means
|
|
88
|
+
self.target_stds = target_stds
|
|
89
|
+
|
|
33
90
|
self.eps = 1e-6
|
|
34
91
|
|
|
35
92
|
def __call__(
|
|
36
|
-
self, patch: np.ndarray, target: Optional[
|
|
37
|
-
) ->
|
|
93
|
+
self, patch: np.ndarray, target: Optional[NDArray] = None
|
|
94
|
+
) -> tuple[NDArray, Optional[NDArray]]:
|
|
38
95
|
"""Apply the transform to the source patch and the target (optional).
|
|
39
96
|
|
|
40
97
|
Parameters
|
|
41
98
|
----------
|
|
42
|
-
patch :
|
|
99
|
+
patch : NDArray
|
|
43
100
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
44
|
-
target :
|
|
45
|
-
Target for the patch, by default None
|
|
101
|
+
target : NDArray, optional
|
|
102
|
+
Target for the patch, by default None.
|
|
46
103
|
|
|
47
104
|
Returns
|
|
48
105
|
-------
|
|
49
|
-
|
|
50
|
-
Transformed patch and target
|
|
106
|
+
tuple of NDArray
|
|
107
|
+
Transformed patch and target, the target can be returned as `None`.
|
|
51
108
|
"""
|
|
52
|
-
|
|
53
|
-
|
|
109
|
+
if len(self.image_means) != patch.shape[0]:
|
|
110
|
+
raise ValueError(
|
|
111
|
+
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
112
|
+
f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
# reshape mean and std and apply the normalization to the patch
|
|
116
|
+
means = _reshape_stats(self.image_means, patch.ndim)
|
|
117
|
+
stds = _reshape_stats(self.image_stds, patch.ndim)
|
|
118
|
+
norm_patch = self._apply(patch, means, stds)
|
|
119
|
+
|
|
120
|
+
# same for the target patch
|
|
121
|
+
if (
|
|
122
|
+
target is not None
|
|
123
|
+
and self.target_means is not None
|
|
124
|
+
and self.target_stds is not None
|
|
125
|
+
):
|
|
126
|
+
target_means = _reshape_stats(self.target_means, target.ndim)
|
|
127
|
+
target_stds = _reshape_stats(self.target_stds, target.ndim)
|
|
128
|
+
norm_target = self._apply(target, target_means, target_stds)
|
|
129
|
+
else:
|
|
130
|
+
norm_target = None
|
|
54
131
|
|
|
55
132
|
return norm_patch, norm_target
|
|
56
133
|
|
|
57
|
-
def _apply(self, patch:
|
|
58
|
-
|
|
134
|
+
def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
135
|
+
"""
|
|
136
|
+
Apply the transform to the image.
|
|
137
|
+
|
|
138
|
+
Parameters
|
|
139
|
+
----------
|
|
140
|
+
patch : NDArray
|
|
141
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
142
|
+
mean : NDArray
|
|
143
|
+
Mean values.
|
|
144
|
+
std : NDArray
|
|
145
|
+
Standard deviations.
|
|
146
|
+
|
|
147
|
+
Returns
|
|
148
|
+
-------
|
|
149
|
+
NDArray
|
|
150
|
+
Normalized image patch.
|
|
151
|
+
"""
|
|
152
|
+
return ((patch - mean) / (std + self.eps)).astype(np.float32)
|
|
59
153
|
|
|
60
154
|
|
|
61
155
|
class Denormalize:
|
|
62
156
|
"""
|
|
63
|
-
Denormalize an image
|
|
157
|
+
Denormalize an image.
|
|
64
158
|
|
|
65
159
|
Denormalization is performed expecting a zero mean and unit variance input. This
|
|
66
160
|
transform expects C(Z)YX dimensions.
|
|
67
161
|
|
|
68
|
-
|
|
162
|
+
Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
69
163
|
division by zero during the normalization step, which is taken into account during
|
|
70
164
|
denormalization.
|
|
71
165
|
|
|
72
|
-
|
|
166
|
+
Parameters
|
|
73
167
|
----------
|
|
74
|
-
|
|
75
|
-
Mean value.
|
|
76
|
-
|
|
77
|
-
Standard deviation value.
|
|
168
|
+
image_means : list or tuple of float
|
|
169
|
+
Mean value per channel.
|
|
170
|
+
image_stds : list or tuple of float
|
|
171
|
+
Standard deviation value per channel.
|
|
172
|
+
|
|
78
173
|
"""
|
|
79
174
|
|
|
80
175
|
def __init__(
|
|
81
176
|
self,
|
|
82
|
-
|
|
83
|
-
|
|
177
|
+
image_means: list[float],
|
|
178
|
+
image_stds: list[float],
|
|
84
179
|
):
|
|
85
|
-
|
|
86
|
-
|
|
180
|
+
"""Constructor.
|
|
181
|
+
|
|
182
|
+
Parameters
|
|
183
|
+
----------
|
|
184
|
+
image_means : list of float
|
|
185
|
+
Mean value per channel.
|
|
186
|
+
image_stds : list of float
|
|
187
|
+
Standard deviation value per channel.
|
|
188
|
+
"""
|
|
189
|
+
self.image_means = image_means
|
|
190
|
+
self.image_stds = image_stds
|
|
191
|
+
|
|
87
192
|
self.eps = 1e-6
|
|
88
193
|
|
|
89
|
-
def __call__(
|
|
90
|
-
|
|
91
|
-
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
92
|
-
"""Apply the transform to the source patch and the target (optional).
|
|
194
|
+
def __call__(self, patch: NDArray) -> NDArray:
|
|
195
|
+
"""Reverse the normalization operation for a batch of patches.
|
|
93
196
|
|
|
94
197
|
Parameters
|
|
95
198
|
----------
|
|
96
|
-
patch :
|
|
97
|
-
Patch, 2D or 3D, shape
|
|
98
|
-
target : Optional[np.ndarray], optional
|
|
99
|
-
Target for the patch, by default None
|
|
199
|
+
patch : NDArray
|
|
200
|
+
Patch, 2D or 3D, shape BC(Z)YX.
|
|
100
201
|
|
|
101
202
|
Returns
|
|
102
203
|
-------
|
|
103
|
-
|
|
104
|
-
Transformed
|
|
204
|
+
NDArray
|
|
205
|
+
Transformed array.
|
|
105
206
|
"""
|
|
106
|
-
|
|
107
|
-
|
|
207
|
+
if len(self.image_means) != patch.shape[1]:
|
|
208
|
+
raise ValueError(
|
|
209
|
+
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
210
|
+
f"number of channels (got shape {patch.shape} for BC(Z)YX) do not "
|
|
211
|
+
f"match."
|
|
212
|
+
)
|
|
108
213
|
|
|
109
|
-
|
|
214
|
+
means = _reshape_stats(self.image_means, patch.ndim)
|
|
215
|
+
stds = _reshape_stats(self.image_stds, patch.ndim)
|
|
110
216
|
|
|
111
|
-
|
|
217
|
+
denorm_array = self._apply(
|
|
218
|
+
patch,
|
|
219
|
+
np.swapaxes(means, 0, 1), # swap axes as C channel is axis 1
|
|
220
|
+
np.swapaxes(stds, 0, 1),
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
return denorm_array.astype(np.float32)
|
|
224
|
+
|
|
225
|
+
def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
112
226
|
"""
|
|
113
227
|
Apply the transform to the image.
|
|
114
228
|
|
|
115
229
|
Parameters
|
|
116
230
|
----------
|
|
117
|
-
|
|
118
|
-
Image
|
|
231
|
+
array : NDArray
|
|
232
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
233
|
+
mean : NDArray
|
|
234
|
+
Mean values.
|
|
235
|
+
std : NDArray
|
|
236
|
+
Standard deviations.
|
|
237
|
+
|
|
238
|
+
Returns
|
|
239
|
+
-------
|
|
240
|
+
NDArray
|
|
241
|
+
Denormalized image array.
|
|
119
242
|
"""
|
|
120
|
-
return
|
|
243
|
+
return array * (std + self.eps) + mean
|
|
@@ -5,7 +5,7 @@ Pixel manipulation is used in N2V and similar algorithm to replace the value of
|
|
|
5
5
|
masked pixels.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import Optional, Tuple
|
|
8
|
+
from typing import Optional, Tuple
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
|
|
@@ -13,9 +13,12 @@ from .struct_mask_parameters import StructMaskParameters
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def _apply_struct_mask(
|
|
16
|
-
patch: np.ndarray,
|
|
16
|
+
patch: np.ndarray,
|
|
17
|
+
coords: np.ndarray,
|
|
18
|
+
struct_params: StructMaskParameters,
|
|
19
|
+
rng: Optional[np.random.Generator] = None,
|
|
17
20
|
) -> np.ndarray:
|
|
18
|
-
"""
|
|
21
|
+
"""Apply structN2V masks to patch.
|
|
19
22
|
|
|
20
23
|
Each point in `coords` corresponds to the center of a mask, masks are paremeterized
|
|
21
24
|
by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by
|
|
@@ -31,12 +34,17 @@ def _apply_struct_mask(
|
|
|
31
34
|
Coordinates of the ROI(subpatch) centers.
|
|
32
35
|
struct_params : StructMaskParameters
|
|
33
36
|
Parameters for the structN2V mask (axis and span).
|
|
37
|
+
rng : np.random.Generator or None
|
|
38
|
+
Random number generator.
|
|
34
39
|
|
|
35
40
|
Returns
|
|
36
41
|
-------
|
|
37
42
|
np.ndarray
|
|
38
43
|
Patch with the structN2V mask applied.
|
|
39
44
|
"""
|
|
45
|
+
if rng is None:
|
|
46
|
+
rng = np.random.default_rng()
|
|
47
|
+
|
|
40
48
|
# relative axis
|
|
41
49
|
moving_axis = -1 - struct_params.axis
|
|
42
50
|
|
|
@@ -67,7 +75,7 @@ def _apply_struct_mask(
|
|
|
67
75
|
mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0)
|
|
68
76
|
|
|
69
77
|
# replace neighbouring pixels with random values from flat dist
|
|
70
|
-
patch[tuple(mix.T)] =
|
|
78
|
+
patch[tuple(mix.T)] = rng.uniform(patch.min(), patch.max(), size=mix.shape[0])
|
|
71
79
|
|
|
72
80
|
return patch
|
|
73
81
|
|
|
@@ -98,7 +106,9 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
|
98
106
|
|
|
99
107
|
|
|
100
108
|
def _get_stratified_coords(
|
|
101
|
-
mask_pixel_perc: float,
|
|
109
|
+
mask_pixel_perc: float,
|
|
110
|
+
shape: Tuple[int, ...],
|
|
111
|
+
rng: Optional[np.random.Generator] = None,
|
|
102
112
|
) -> np.ndarray:
|
|
103
113
|
"""
|
|
104
114
|
Generate coordinates of the pixels to mask.
|
|
@@ -113,6 +123,8 @@ def _get_stratified_coords(
|
|
|
113
123
|
calculating the distance between masked pixels across each axis.
|
|
114
124
|
shape : Tuple[int, ...]
|
|
115
125
|
Shape of the input patch.
|
|
126
|
+
rng : np.random.Generator or None
|
|
127
|
+
Random number generator.
|
|
116
128
|
|
|
117
129
|
Returns
|
|
118
130
|
-------
|
|
@@ -124,7 +136,8 @@ def _get_stratified_coords(
|
|
|
124
136
|
"Calculating coordinates is only possible for 2D and 3D patches"
|
|
125
137
|
)
|
|
126
138
|
|
|
127
|
-
rng
|
|
139
|
+
if rng is None:
|
|
140
|
+
rng = np.random.default_rng()
|
|
128
141
|
|
|
129
142
|
mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
|
|
130
143
|
np.int32
|
|
@@ -228,6 +241,7 @@ def uniform_manipulate(
|
|
|
228
241
|
subpatch_size: int = 11,
|
|
229
242
|
remove_center: bool = True,
|
|
230
243
|
struct_params: Optional[StructMaskParameters] = None,
|
|
244
|
+
rng: Optional[np.random.Generator] = None,
|
|
231
245
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
232
246
|
"""
|
|
233
247
|
Manipulate pixels by replacing them with a neighbor values.
|
|
@@ -248,19 +262,23 @@ def uniform_manipulate(
|
|
|
248
262
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
249
263
|
remove_center : bool
|
|
250
264
|
Whether to remove the center pixel from the subpatch, by default False.
|
|
251
|
-
struct_params:
|
|
265
|
+
struct_params : StructMaskParameters or None
|
|
252
266
|
Parameters for the structN2V mask (axis and span).
|
|
267
|
+
rng : np.random.Generator or None
|
|
268
|
+
Random number generator.
|
|
253
269
|
|
|
254
270
|
Returns
|
|
255
271
|
-------
|
|
256
272
|
Tuple[np.ndarray]
|
|
257
273
|
Tuple containing the manipulated patch and the corresponding mask.
|
|
258
274
|
"""
|
|
275
|
+
if rng is None:
|
|
276
|
+
rng = np.random.default_rng()
|
|
277
|
+
|
|
259
278
|
# Get the coordinates of the pixels to be replaced
|
|
260
279
|
transformed_patch = patch.copy()
|
|
261
280
|
|
|
262
|
-
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
|
|
263
|
-
rng = np.random.default_rng()
|
|
281
|
+
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
|
|
264
282
|
|
|
265
283
|
# Generate coordinate grid for subpatch
|
|
266
284
|
roi_span_full = np.arange(
|
|
@@ -303,6 +321,7 @@ def median_manipulate(
|
|
|
303
321
|
mask_pixel_percentage: float,
|
|
304
322
|
subpatch_size: int = 11,
|
|
305
323
|
struct_params: Optional[StructMaskParameters] = None,
|
|
324
|
+
rng: Optional[np.random.Generator] = None,
|
|
306
325
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
307
326
|
"""
|
|
308
327
|
Manipulate pixels by replacing them with the median of their surrounding subpatch.
|
|
@@ -322,18 +341,23 @@ def median_manipulate(
|
|
|
322
341
|
Approximate percentage of pixels to be masked.
|
|
323
342
|
subpatch_size : int
|
|
324
343
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
325
|
-
struct_params:
|
|
344
|
+
struct_params : StructMaskParameters or None, optional
|
|
326
345
|
Parameters for the structN2V mask (axis and span).
|
|
346
|
+
rng : np.random.Generator or None, optional
|
|
347
|
+
Random number generato, by default None.
|
|
327
348
|
|
|
328
349
|
Returns
|
|
329
350
|
-------
|
|
330
351
|
Tuple[np.ndarray]
|
|
331
352
|
Tuple containing the manipulated patch, the original patch and the mask.
|
|
332
353
|
"""
|
|
354
|
+
if rng is None:
|
|
355
|
+
rng = np.random.default_rng()
|
|
356
|
+
|
|
333
357
|
transformed_patch = patch.copy()
|
|
334
358
|
|
|
335
359
|
# Get the coordinates of the pixels to be replaced
|
|
336
|
-
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
|
|
360
|
+
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
|
|
337
361
|
|
|
338
362
|
# Generate coordinate grid for subpatch
|
|
339
363
|
roi_span = np.array(
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Class representing the parameters of structN2V masks."""
|
|
2
|
+
|
|
1
3
|
from dataclasses import dataclass
|
|
2
4
|
from typing import Literal
|
|
3
5
|
|
|
@@ -6,7 +8,7 @@ from typing import Literal
|
|
|
6
8
|
class StructMaskParameters:
|
|
7
9
|
"""Parameters of structN2V masks.
|
|
8
10
|
|
|
9
|
-
|
|
11
|
+
Attributes
|
|
10
12
|
----------
|
|
11
13
|
axis : Literal[0, 1]
|
|
12
14
|
Axis along which to apply the mask, horizontal (0) or vertical (1).
|
|
@@ -1,33 +1,24 @@
|
|
|
1
1
|
"""A general parent class for transforms."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
3
|
+
from typing import Any
|
|
6
4
|
|
|
7
5
|
|
|
8
6
|
class Transform:
|
|
9
7
|
"""A general parent class for transforms."""
|
|
10
8
|
|
|
11
|
-
def __call__(
|
|
12
|
-
|
|
13
|
-
) -> Tuple[np.ndarray, ...]:
|
|
14
|
-
"""Apply the transform to the input data.
|
|
9
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
10
|
+
"""Apply the transform.
|
|
15
11
|
|
|
16
12
|
Parameters
|
|
17
13
|
----------
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
14
|
+
*args : Any
|
|
15
|
+
Arguments.
|
|
16
|
+
**kwargs : Any
|
|
17
|
+
Keyword arguments.
|
|
22
18
|
|
|
23
19
|
Returns
|
|
24
20
|
-------
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
Raises
|
|
29
|
-
------
|
|
30
|
-
NotImplementedError
|
|
31
|
-
This method should be implemented in the child class.
|
|
21
|
+
Any
|
|
22
|
+
Transformed data.
|
|
32
23
|
"""
|
|
33
|
-
|
|
24
|
+
pass
|
careamics/transforms/tta.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
|
1
1
|
"""Test-time augmentations."""
|
|
2
2
|
|
|
3
|
-
from typing import List
|
|
4
|
-
|
|
5
3
|
from torch import Tensor, flip, mean, rot90, stack
|
|
6
4
|
|
|
7
5
|
|
|
8
|
-
# TODO add tests
|
|
9
6
|
class ImageRestorationTTA:
|
|
10
7
|
"""
|
|
11
8
|
Test-time augmentation for image restoration tasks.
|
|
@@ -13,62 +10,79 @@ class ImageRestorationTTA:
|
|
|
13
10
|
The augmentation is performed using all 90 deg rotations and their flipped version,
|
|
14
11
|
as well as the original image flipped.
|
|
15
12
|
|
|
16
|
-
Tensors should be of shape SC(Z)YX
|
|
13
|
+
Tensors should be of shape SC(Z)YX.
|
|
17
14
|
|
|
18
15
|
This transformation is used in the LightningModule in order to perform test-time
|
|
19
|
-
|
|
16
|
+
augmentation.
|
|
20
17
|
"""
|
|
21
18
|
|
|
22
|
-
def
|
|
23
|
-
"""Constructor."""
|
|
24
|
-
pass
|
|
25
|
-
|
|
26
|
-
def forward(self, x: Tensor) -> List[Tensor]:
|
|
19
|
+
def forward(self, input_tensor: Tensor) -> list[Tensor]:
|
|
27
20
|
"""
|
|
28
21
|
Apply test-time augmentation to the input tensor.
|
|
29
22
|
|
|
30
23
|
Parameters
|
|
31
24
|
----------
|
|
32
|
-
|
|
25
|
+
input_tensor : Tensor
|
|
33
26
|
Input tensor, shape SC(Z)YX.
|
|
34
27
|
|
|
35
28
|
Returns
|
|
36
29
|
-------
|
|
37
|
-
|
|
30
|
+
list of torch.Tensor
|
|
38
31
|
List of augmented tensors.
|
|
39
32
|
"""
|
|
33
|
+
# axes: only applies to YX axes
|
|
34
|
+
axes = (-2, -1)
|
|
35
|
+
|
|
40
36
|
augmented = [
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
rot90(
|
|
37
|
+
# original
|
|
38
|
+
input_tensor,
|
|
39
|
+
# rotations
|
|
40
|
+
rot90(input_tensor, 1, dims=axes),
|
|
41
|
+
rot90(input_tensor, 2, dims=axes),
|
|
42
|
+
rot90(input_tensor, 3, dims=axes),
|
|
43
|
+
# original flipped
|
|
44
|
+
flip(input_tensor, dims=(axes[0],)),
|
|
45
|
+
flip(input_tensor, dims=(axes[1],)),
|
|
45
46
|
]
|
|
46
|
-
augmented_flip = augmented.copy()
|
|
47
|
-
for x_ in augmented:
|
|
48
|
-
augmented_flip.append(flip(x_, dims=(-3, -1)))
|
|
49
|
-
return augmented_flip
|
|
50
47
|
|
|
51
|
-
|
|
48
|
+
# rotated once, flipped
|
|
49
|
+
augmented.extend(
|
|
50
|
+
[
|
|
51
|
+
flip(augmented[1], dims=(axes[0],)),
|
|
52
|
+
flip(augmented[1], dims=(axes[1],)),
|
|
53
|
+
]
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
return augmented
|
|
57
|
+
|
|
58
|
+
def backward(self, x: list[Tensor]) -> Tensor:
|
|
52
59
|
"""Undo the test-time augmentation.
|
|
53
60
|
|
|
54
61
|
Parameters
|
|
55
62
|
----------
|
|
56
63
|
x : Any
|
|
57
|
-
List of augmented tensors.
|
|
64
|
+
List of augmented tensors of shape SC(Z)YX.
|
|
58
65
|
|
|
59
66
|
Returns
|
|
60
67
|
-------
|
|
61
68
|
Any
|
|
62
69
|
Original tensor.
|
|
63
70
|
"""
|
|
71
|
+
axes = (-2, -1)
|
|
72
|
+
|
|
64
73
|
reverse = [
|
|
74
|
+
# original
|
|
65
75
|
x[0],
|
|
66
|
-
|
|
67
|
-
rot90(x[
|
|
68
|
-
rot90(x[
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
76
|
+
# rotated
|
|
77
|
+
rot90(x[1], -1, dims=axes),
|
|
78
|
+
rot90(x[2], -2, dims=axes),
|
|
79
|
+
rot90(x[3], -3, dims=axes),
|
|
80
|
+
# original flipped
|
|
81
|
+
flip(x[4], dims=(axes[0],)),
|
|
82
|
+
flip(x[5], dims=(axes[1],)),
|
|
83
|
+
# rotated once, flipped
|
|
84
|
+
rot90(flip(x[6], dims=(axes[0],)), -1, dims=axes),
|
|
85
|
+
rot90(flip(x[7], dims=(axes[1],)), -1, dims=axes),
|
|
73
86
|
]
|
|
87
|
+
|
|
74
88
|
return mean(stack(reverse), dim=0)
|