careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,12 +1,34 @@
|
|
|
1
1
|
"""Normalization and denormalization transforms for image patches."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
6
7
|
|
|
7
8
|
from careamics.transforms.transform import Transform
|
|
8
9
|
|
|
9
10
|
|
|
11
|
+
def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
|
|
12
|
+
"""Reshape stats to match the number of dimensions of the input image.
|
|
13
|
+
|
|
14
|
+
This allows to broadcast the stats (mean or std) to the image dimensions, and
|
|
15
|
+
thus directly perform a vectorial calculation.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
stats : list of float
|
|
20
|
+
List of stats, mean or standard deviation.
|
|
21
|
+
ndim : int
|
|
22
|
+
Number of dimensions of the image, including the C channel.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
NDArray
|
|
27
|
+
Reshaped stats.
|
|
28
|
+
"""
|
|
29
|
+
return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
|
|
30
|
+
|
|
31
|
+
|
|
10
32
|
class Normalize(Transform):
|
|
11
33
|
"""
|
|
12
34
|
Normalize an image or image patch.
|
|
@@ -19,154 +41,203 @@ class Normalize(Transform):
|
|
|
19
41
|
|
|
20
42
|
Parameters
|
|
21
43
|
----------
|
|
22
|
-
|
|
23
|
-
Mean value.
|
|
24
|
-
|
|
25
|
-
Standard deviation value.
|
|
44
|
+
image_means : list of float
|
|
45
|
+
Mean value per channel.
|
|
46
|
+
image_stds : list of float
|
|
47
|
+
Standard deviation value per channel.
|
|
48
|
+
target_means : list of float, optional
|
|
49
|
+
Target mean value per channel, by default None.
|
|
50
|
+
target_stds : list of float, optional
|
|
51
|
+
Target standard deviation value per channel, by default None.
|
|
26
52
|
|
|
27
53
|
Attributes
|
|
28
54
|
----------
|
|
29
|
-
|
|
30
|
-
Mean value.
|
|
31
|
-
|
|
32
|
-
Standard deviation value.
|
|
55
|
+
image_means : list of float
|
|
56
|
+
Mean value per channel.
|
|
57
|
+
image_stds : list of float
|
|
58
|
+
Standard deviation value per channel.
|
|
59
|
+
target_means :list of float, optional
|
|
60
|
+
Target mean value per channel, by default None.
|
|
61
|
+
target_stds : list of float, optional
|
|
62
|
+
Target standard deviation value per channel, by default None.
|
|
33
63
|
"""
|
|
34
64
|
|
|
35
65
|
def __init__(
|
|
36
66
|
self,
|
|
37
|
-
|
|
38
|
-
|
|
67
|
+
image_means: list[float],
|
|
68
|
+
image_stds: list[float],
|
|
69
|
+
target_means: Optional[list[float]] = None,
|
|
70
|
+
target_stds: Optional[list[float]] = None,
|
|
39
71
|
):
|
|
40
72
|
"""Constructor.
|
|
41
73
|
|
|
42
74
|
Parameters
|
|
43
75
|
----------
|
|
44
|
-
|
|
45
|
-
Mean value.
|
|
46
|
-
|
|
47
|
-
Standard deviation value.
|
|
76
|
+
image_means : list of float
|
|
77
|
+
Mean value per channel.
|
|
78
|
+
image_stds : list of float
|
|
79
|
+
Standard deviation value per channel.
|
|
80
|
+
target_means : list of float, optional
|
|
81
|
+
Target mean value per channel, by default None.
|
|
82
|
+
target_stds : list of float, optional
|
|
83
|
+
Target standard deviation value per channel, by default None.
|
|
48
84
|
"""
|
|
49
|
-
self.
|
|
50
|
-
self.
|
|
85
|
+
self.image_means = image_means
|
|
86
|
+
self.image_stds = image_stds
|
|
87
|
+
self.target_means = target_means
|
|
88
|
+
self.target_stds = target_stds
|
|
89
|
+
|
|
51
90
|
self.eps = 1e-6
|
|
52
91
|
|
|
53
92
|
def __call__(
|
|
54
|
-
self, patch: np.ndarray, target: Optional[
|
|
55
|
-
) ->
|
|
93
|
+
self, patch: np.ndarray, target: Optional[NDArray] = None
|
|
94
|
+
) -> tuple[NDArray, Optional[NDArray]]:
|
|
56
95
|
"""Apply the transform to the source patch and the target (optional).
|
|
57
96
|
|
|
58
97
|
Parameters
|
|
59
98
|
----------
|
|
60
|
-
patch :
|
|
99
|
+
patch : NDArray
|
|
61
100
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
62
|
-
target :
|
|
101
|
+
target : NDArray, optional
|
|
63
102
|
Target for the patch, by default None.
|
|
64
103
|
|
|
65
104
|
Returns
|
|
66
105
|
-------
|
|
67
|
-
|
|
68
|
-
Transformed patch and target
|
|
106
|
+
tuple of NDArray
|
|
107
|
+
Transformed patch and target, the target can be returned as `None`.
|
|
69
108
|
"""
|
|
70
|
-
|
|
71
|
-
|
|
109
|
+
if len(self.image_means) != patch.shape[0]:
|
|
110
|
+
raise ValueError(
|
|
111
|
+
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
112
|
+
f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
# reshape mean and std and apply the normalization to the patch
|
|
116
|
+
means = _reshape_stats(self.image_means, patch.ndim)
|
|
117
|
+
stds = _reshape_stats(self.image_stds, patch.ndim)
|
|
118
|
+
norm_patch = self._apply(patch, means, stds)
|
|
119
|
+
|
|
120
|
+
# same for the target patch
|
|
121
|
+
if (
|
|
122
|
+
target is not None
|
|
123
|
+
and self.target_means is not None
|
|
124
|
+
and self.target_stds is not None
|
|
125
|
+
):
|
|
126
|
+
target_means = _reshape_stats(self.target_means, target.ndim)
|
|
127
|
+
target_stds = _reshape_stats(self.target_stds, target.ndim)
|
|
128
|
+
norm_target = self._apply(target, target_means, target_stds)
|
|
129
|
+
else:
|
|
130
|
+
norm_target = None
|
|
72
131
|
|
|
73
132
|
return norm_patch, norm_target
|
|
74
133
|
|
|
75
|
-
def _apply(self, patch:
|
|
134
|
+
def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
76
135
|
"""
|
|
77
136
|
Apply the transform to the image.
|
|
78
137
|
|
|
79
138
|
Parameters
|
|
80
139
|
----------
|
|
81
|
-
patch :
|
|
140
|
+
patch : NDArray
|
|
82
141
|
Image patch, 2D or 3D, shape C(Z)YX.
|
|
142
|
+
mean : NDArray
|
|
143
|
+
Mean values.
|
|
144
|
+
std : NDArray
|
|
145
|
+
Standard deviations.
|
|
83
146
|
|
|
84
147
|
Returns
|
|
85
148
|
-------
|
|
86
|
-
|
|
87
|
-
|
|
149
|
+
NDArray
|
|
150
|
+
Normalized image patch.
|
|
88
151
|
"""
|
|
89
|
-
return ((patch -
|
|
152
|
+
return ((patch - mean) / (std + self.eps)).astype(np.float32)
|
|
90
153
|
|
|
91
154
|
|
|
92
155
|
class Denormalize:
|
|
93
156
|
"""
|
|
94
|
-
Denormalize an image
|
|
157
|
+
Denormalize an image.
|
|
95
158
|
|
|
96
159
|
Denormalization is performed expecting a zero mean and unit variance input. This
|
|
97
160
|
transform expects C(Z)YX dimensions.
|
|
98
161
|
|
|
99
|
-
|
|
162
|
+
Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
100
163
|
division by zero during the normalization step, which is taken into account during
|
|
101
164
|
denormalization.
|
|
102
165
|
|
|
103
166
|
Parameters
|
|
104
167
|
----------
|
|
105
|
-
|
|
106
|
-
Mean value.
|
|
107
|
-
|
|
108
|
-
Standard deviation value.
|
|
168
|
+
image_means : list or tuple of float
|
|
169
|
+
Mean value per channel.
|
|
170
|
+
image_stds : list or tuple of float
|
|
171
|
+
Standard deviation value per channel.
|
|
109
172
|
|
|
110
|
-
Attributes
|
|
111
|
-
----------
|
|
112
|
-
mean : float
|
|
113
|
-
Mean value.
|
|
114
|
-
std : float
|
|
115
|
-
Standard deviation value.
|
|
116
173
|
"""
|
|
117
174
|
|
|
118
175
|
def __init__(
|
|
119
176
|
self,
|
|
120
|
-
|
|
121
|
-
|
|
177
|
+
image_means: list[float],
|
|
178
|
+
image_stds: list[float],
|
|
122
179
|
):
|
|
123
180
|
"""Constructor.
|
|
124
181
|
|
|
125
182
|
Parameters
|
|
126
183
|
----------
|
|
127
|
-
|
|
128
|
-
Mean.
|
|
129
|
-
|
|
130
|
-
Standard deviation.
|
|
184
|
+
image_means : list of float
|
|
185
|
+
Mean value per channel.
|
|
186
|
+
image_stds : list of float
|
|
187
|
+
Standard deviation value per channel.
|
|
131
188
|
"""
|
|
132
|
-
self.
|
|
133
|
-
self.
|
|
189
|
+
self.image_means = image_means
|
|
190
|
+
self.image_stds = image_stds
|
|
191
|
+
|
|
134
192
|
self.eps = 1e-6
|
|
135
193
|
|
|
136
|
-
def __call__(
|
|
137
|
-
|
|
138
|
-
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
139
|
-
"""Apply the transform to the source patch and the target (optional).
|
|
194
|
+
def __call__(self, patch: NDArray) -> NDArray:
|
|
195
|
+
"""Reverse the normalization operation for a batch of patches.
|
|
140
196
|
|
|
141
197
|
Parameters
|
|
142
198
|
----------
|
|
143
|
-
patch :
|
|
144
|
-
Patch, 2D or 3D, shape
|
|
145
|
-
target : Optional[np.ndarray], optional
|
|
146
|
-
Target for the patch, by default None.
|
|
199
|
+
patch : NDArray
|
|
200
|
+
Patch, 2D or 3D, shape BC(Z)YX.
|
|
147
201
|
|
|
148
202
|
Returns
|
|
149
203
|
-------
|
|
150
|
-
|
|
151
|
-
Transformed
|
|
204
|
+
NDArray
|
|
205
|
+
Transformed array.
|
|
152
206
|
"""
|
|
153
|
-
|
|
154
|
-
|
|
207
|
+
if len(self.image_means) != patch.shape[1]:
|
|
208
|
+
raise ValueError(
|
|
209
|
+
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
210
|
+
f"number of channels (got shape {patch.shape} for BC(Z)YX) do not "
|
|
211
|
+
f"match."
|
|
212
|
+
)
|
|
155
213
|
|
|
156
|
-
|
|
214
|
+
means = _reshape_stats(self.image_means, patch.ndim)
|
|
215
|
+
stds = _reshape_stats(self.image_stds, patch.ndim)
|
|
216
|
+
|
|
217
|
+
denorm_array = self._apply(
|
|
218
|
+
patch,
|
|
219
|
+
np.swapaxes(means, 0, 1), # swap axes as C channel is axis 1
|
|
220
|
+
np.swapaxes(stds, 0, 1),
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
return denorm_array.astype(np.float32)
|
|
157
224
|
|
|
158
|
-
def _apply(self,
|
|
225
|
+
def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
|
|
159
226
|
"""
|
|
160
227
|
Apply the transform to the image.
|
|
161
228
|
|
|
162
229
|
Parameters
|
|
163
230
|
----------
|
|
164
|
-
|
|
231
|
+
array : NDArray
|
|
165
232
|
Image patch, 2D or 3D, shape C(Z)YX.
|
|
233
|
+
mean : NDArray
|
|
234
|
+
Mean values.
|
|
235
|
+
std : NDArray
|
|
236
|
+
Standard deviations.
|
|
166
237
|
|
|
167
238
|
Returns
|
|
168
239
|
-------
|
|
169
|
-
|
|
170
|
-
Denormalized image
|
|
240
|
+
NDArray
|
|
241
|
+
Denormalized image array.
|
|
171
242
|
"""
|
|
172
|
-
return
|
|
243
|
+
return array * (std + self.eps) + mean
|
|
@@ -13,7 +13,10 @@ from .struct_mask_parameters import StructMaskParameters
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def _apply_struct_mask(
|
|
16
|
-
patch: np.ndarray,
|
|
16
|
+
patch: np.ndarray,
|
|
17
|
+
coords: np.ndarray,
|
|
18
|
+
struct_params: StructMaskParameters,
|
|
19
|
+
rng: Optional[np.random.Generator] = None,
|
|
17
20
|
) -> np.ndarray:
|
|
18
21
|
"""Apply structN2V masks to patch.
|
|
19
22
|
|
|
@@ -31,12 +34,17 @@ def _apply_struct_mask(
|
|
|
31
34
|
Coordinates of the ROI(subpatch) centers.
|
|
32
35
|
struct_params : StructMaskParameters
|
|
33
36
|
Parameters for the structN2V mask (axis and span).
|
|
37
|
+
rng : np.random.Generator or None
|
|
38
|
+
Random number generator.
|
|
34
39
|
|
|
35
40
|
Returns
|
|
36
41
|
-------
|
|
37
42
|
np.ndarray
|
|
38
43
|
Patch with the structN2V mask applied.
|
|
39
44
|
"""
|
|
45
|
+
if rng is None:
|
|
46
|
+
rng = np.random.default_rng()
|
|
47
|
+
|
|
40
48
|
# relative axis
|
|
41
49
|
moving_axis = -1 - struct_params.axis
|
|
42
50
|
|
|
@@ -67,7 +75,7 @@ def _apply_struct_mask(
|
|
|
67
75
|
mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0)
|
|
68
76
|
|
|
69
77
|
# replace neighbouring pixels with random values from flat dist
|
|
70
|
-
patch[tuple(mix.T)] =
|
|
78
|
+
patch[tuple(mix.T)] = rng.uniform(patch.min(), patch.max(), size=mix.shape[0])
|
|
71
79
|
|
|
72
80
|
return patch
|
|
73
81
|
|
|
@@ -98,7 +106,9 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
|
98
106
|
|
|
99
107
|
|
|
100
108
|
def _get_stratified_coords(
|
|
101
|
-
mask_pixel_perc: float,
|
|
109
|
+
mask_pixel_perc: float,
|
|
110
|
+
shape: Tuple[int, ...],
|
|
111
|
+
rng: Optional[np.random.Generator] = None,
|
|
102
112
|
) -> np.ndarray:
|
|
103
113
|
"""
|
|
104
114
|
Generate coordinates of the pixels to mask.
|
|
@@ -113,6 +123,8 @@ def _get_stratified_coords(
|
|
|
113
123
|
calculating the distance between masked pixels across each axis.
|
|
114
124
|
shape : Tuple[int, ...]
|
|
115
125
|
Shape of the input patch.
|
|
126
|
+
rng : np.random.Generator or None
|
|
127
|
+
Random number generator.
|
|
116
128
|
|
|
117
129
|
Returns
|
|
118
130
|
-------
|
|
@@ -124,7 +136,8 @@ def _get_stratified_coords(
|
|
|
124
136
|
"Calculating coordinates is only possible for 2D and 3D patches"
|
|
125
137
|
)
|
|
126
138
|
|
|
127
|
-
rng
|
|
139
|
+
if rng is None:
|
|
140
|
+
rng = np.random.default_rng()
|
|
128
141
|
|
|
129
142
|
mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
|
|
130
143
|
np.int32
|
|
@@ -228,6 +241,7 @@ def uniform_manipulate(
|
|
|
228
241
|
subpatch_size: int = 11,
|
|
229
242
|
remove_center: bool = True,
|
|
230
243
|
struct_params: Optional[StructMaskParameters] = None,
|
|
244
|
+
rng: Optional[np.random.Generator] = None,
|
|
231
245
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
232
246
|
"""
|
|
233
247
|
Manipulate pixels by replacing them with a neighbor values.
|
|
@@ -248,19 +262,23 @@ def uniform_manipulate(
|
|
|
248
262
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
249
263
|
remove_center : bool
|
|
250
264
|
Whether to remove the center pixel from the subpatch, by default False.
|
|
251
|
-
struct_params :
|
|
265
|
+
struct_params : StructMaskParameters or None
|
|
252
266
|
Parameters for the structN2V mask (axis and span).
|
|
267
|
+
rng : np.random.Generator or None
|
|
268
|
+
Random number generator.
|
|
253
269
|
|
|
254
270
|
Returns
|
|
255
271
|
-------
|
|
256
272
|
Tuple[np.ndarray]
|
|
257
273
|
Tuple containing the manipulated patch and the corresponding mask.
|
|
258
274
|
"""
|
|
275
|
+
if rng is None:
|
|
276
|
+
rng = np.random.default_rng()
|
|
277
|
+
|
|
259
278
|
# Get the coordinates of the pixels to be replaced
|
|
260
279
|
transformed_patch = patch.copy()
|
|
261
280
|
|
|
262
|
-
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
|
|
263
|
-
rng = np.random.default_rng()
|
|
281
|
+
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
|
|
264
282
|
|
|
265
283
|
# Generate coordinate grid for subpatch
|
|
266
284
|
roi_span_full = np.arange(
|
|
@@ -303,6 +321,7 @@ def median_manipulate(
|
|
|
303
321
|
mask_pixel_percentage: float,
|
|
304
322
|
subpatch_size: int = 11,
|
|
305
323
|
struct_params: Optional[StructMaskParameters] = None,
|
|
324
|
+
rng: Optional[np.random.Generator] = None,
|
|
306
325
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
307
326
|
"""
|
|
308
327
|
Manipulate pixels by replacing them with the median of their surrounding subpatch.
|
|
@@ -322,18 +341,23 @@ def median_manipulate(
|
|
|
322
341
|
Approximate percentage of pixels to be masked.
|
|
323
342
|
subpatch_size : int
|
|
324
343
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
325
|
-
struct_params :
|
|
344
|
+
struct_params : StructMaskParameters or None, optional
|
|
326
345
|
Parameters for the structN2V mask (axis and span).
|
|
346
|
+
rng : np.random.Generator or None, optional
|
|
347
|
+
Random number generato, by default None.
|
|
327
348
|
|
|
328
349
|
Returns
|
|
329
350
|
-------
|
|
330
351
|
Tuple[np.ndarray]
|
|
331
352
|
Tuple containing the manipulated patch, the original patch and the mask.
|
|
332
353
|
"""
|
|
354
|
+
if rng is None:
|
|
355
|
+
rng = np.random.default_rng()
|
|
356
|
+
|
|
333
357
|
transformed_patch = patch.copy()
|
|
334
358
|
|
|
335
359
|
# Get the coordinates of the pixels to be replaced
|
|
336
|
-
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
|
|
360
|
+
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
|
|
337
361
|
|
|
338
362
|
# Generate coordinate grid for subpatch
|
|
339
363
|
roi_span = np.array(
|
careamics/transforms/tta.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
|
1
1
|
"""Test-time augmentations."""
|
|
2
2
|
|
|
3
|
-
from typing import List
|
|
4
|
-
|
|
5
3
|
from torch import Tensor, flip, mean, rot90, stack
|
|
6
4
|
|
|
7
5
|
|
|
8
|
-
# TODO add tests
|
|
9
6
|
class ImageRestorationTTA:
|
|
10
7
|
"""
|
|
11
8
|
Test-time augmentation for image restoration tasks.
|
|
@@ -13,62 +10,79 @@ class ImageRestorationTTA:
|
|
|
13
10
|
The augmentation is performed using all 90 deg rotations and their flipped version,
|
|
14
11
|
as well as the original image flipped.
|
|
15
12
|
|
|
16
|
-
Tensors should be of shape SC(Z)YX
|
|
13
|
+
Tensors should be of shape SC(Z)YX.
|
|
17
14
|
|
|
18
15
|
This transformation is used in the LightningModule in order to perform test-time
|
|
19
|
-
|
|
16
|
+
augmentation.
|
|
20
17
|
"""
|
|
21
18
|
|
|
22
|
-
def
|
|
23
|
-
"""Constructor."""
|
|
24
|
-
pass
|
|
25
|
-
|
|
26
|
-
def forward(self, x: Tensor) -> List[Tensor]:
|
|
19
|
+
def forward(self, input_tensor: Tensor) -> list[Tensor]:
|
|
27
20
|
"""
|
|
28
21
|
Apply test-time augmentation to the input tensor.
|
|
29
22
|
|
|
30
23
|
Parameters
|
|
31
24
|
----------
|
|
32
|
-
|
|
25
|
+
input_tensor : Tensor
|
|
33
26
|
Input tensor, shape SC(Z)YX.
|
|
34
27
|
|
|
35
28
|
Returns
|
|
36
29
|
-------
|
|
37
|
-
|
|
30
|
+
list of torch.Tensor
|
|
38
31
|
List of augmented tensors.
|
|
39
32
|
"""
|
|
33
|
+
# axes: only applies to YX axes
|
|
34
|
+
axes = (-2, -1)
|
|
35
|
+
|
|
40
36
|
augmented = [
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
rot90(
|
|
37
|
+
# original
|
|
38
|
+
input_tensor,
|
|
39
|
+
# rotations
|
|
40
|
+
rot90(input_tensor, 1, dims=axes),
|
|
41
|
+
rot90(input_tensor, 2, dims=axes),
|
|
42
|
+
rot90(input_tensor, 3, dims=axes),
|
|
43
|
+
# original flipped
|
|
44
|
+
flip(input_tensor, dims=(axes[0],)),
|
|
45
|
+
flip(input_tensor, dims=(axes[1],)),
|
|
45
46
|
]
|
|
46
|
-
augmented_flip = augmented.copy()
|
|
47
|
-
for x_ in augmented:
|
|
48
|
-
augmented_flip.append(flip(x_, dims=(-3, -1)))
|
|
49
|
-
return augmented_flip
|
|
50
47
|
|
|
51
|
-
|
|
48
|
+
# rotated once, flipped
|
|
49
|
+
augmented.extend(
|
|
50
|
+
[
|
|
51
|
+
flip(augmented[1], dims=(axes[0],)),
|
|
52
|
+
flip(augmented[1], dims=(axes[1],)),
|
|
53
|
+
]
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
return augmented
|
|
57
|
+
|
|
58
|
+
def backward(self, x: list[Tensor]) -> Tensor:
|
|
52
59
|
"""Undo the test-time augmentation.
|
|
53
60
|
|
|
54
61
|
Parameters
|
|
55
62
|
----------
|
|
56
63
|
x : Any
|
|
57
|
-
List of augmented tensors.
|
|
64
|
+
List of augmented tensors of shape SC(Z)YX.
|
|
58
65
|
|
|
59
66
|
Returns
|
|
60
67
|
-------
|
|
61
68
|
Any
|
|
62
69
|
Original tensor.
|
|
63
70
|
"""
|
|
71
|
+
axes = (-2, -1)
|
|
72
|
+
|
|
64
73
|
reverse = [
|
|
74
|
+
# original
|
|
65
75
|
x[0],
|
|
66
|
-
|
|
67
|
-
rot90(x[
|
|
68
|
-
rot90(x[
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
76
|
+
# rotated
|
|
77
|
+
rot90(x[1], -1, dims=axes),
|
|
78
|
+
rot90(x[2], -2, dims=axes),
|
|
79
|
+
rot90(x[3], -3, dims=axes),
|
|
80
|
+
# original flipped
|
|
81
|
+
flip(x[4], dims=(axes[0],)),
|
|
82
|
+
flip(x[5], dims=(axes[1],)),
|
|
83
|
+
# rotated once, flipped
|
|
84
|
+
rot90(flip(x[6], dims=(axes[0],)), -1, dims=axes),
|
|
85
|
+
rot90(flip(x[7], dims=(axes[1],)), -1, dims=axes),
|
|
73
86
|
]
|
|
87
|
+
|
|
74
88
|
return mean(stack(reverse), dim=0)
|
careamics/utils/__init__.py
CHANGED
|
@@ -7,9 +7,11 @@ __all__ = [
|
|
|
7
7
|
"BaseEnum",
|
|
8
8
|
"get_logger",
|
|
9
9
|
"get_careamics_home",
|
|
10
|
+
"autocorrelation",
|
|
10
11
|
]
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
from .autocorrelation import autocorrelation
|
|
13
15
|
from .base_enum import BaseEnum
|
|
14
16
|
from .context import cwd, get_careamics_home
|
|
15
17
|
from .logging import get_logger
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Autocorrelation function."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numpy.typing import NDArray
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def autocorrelation(image: NDArray) -> NDArray:
|
|
8
|
+
"""Compute the autocorrelation of an image.
|
|
9
|
+
|
|
10
|
+
This method is used to explore spatial correlations in images,
|
|
11
|
+
in particular in the noise.
|
|
12
|
+
|
|
13
|
+
The autocorrelation is normalized to the zero-shift value, which is centered in
|
|
14
|
+
the resulting images.
|
|
15
|
+
|
|
16
|
+
Parameters
|
|
17
|
+
----------
|
|
18
|
+
image : NDArray
|
|
19
|
+
Input image.
|
|
20
|
+
|
|
21
|
+
Returns
|
|
22
|
+
-------
|
|
23
|
+
numpy.ndarray
|
|
24
|
+
Autocorrelation of the input image.
|
|
25
|
+
"""
|
|
26
|
+
# normalize image
|
|
27
|
+
image = (image - np.mean(image)) / np.std(image)
|
|
28
|
+
|
|
29
|
+
# compute autocorrelation in fourier space
|
|
30
|
+
image = np.fft.fftn(image)
|
|
31
|
+
image = np.abs(image) ** 2
|
|
32
|
+
image = np.fft.ifftn(image).real
|
|
33
|
+
|
|
34
|
+
# normalize to zero shift value
|
|
35
|
+
image = image / image.flat[0]
|
|
36
|
+
|
|
37
|
+
# shift zero frequency to center
|
|
38
|
+
image = np.fft.fftshift(image)
|
|
39
|
+
|
|
40
|
+
return image
|
careamics/utils/ram.py
CHANGED
|
@@ -5,11 +5,11 @@ import psutil
|
|
|
5
5
|
|
|
6
6
|
def get_ram_size() -> int:
|
|
7
7
|
"""
|
|
8
|
-
Get RAM size in
|
|
8
|
+
Get RAM size in mbytes.
|
|
9
9
|
|
|
10
10
|
Returns
|
|
11
11
|
-------
|
|
12
12
|
int
|
|
13
13
|
RAM size in mbytes.
|
|
14
14
|
"""
|
|
15
|
-
return psutil.virtual_memory().
|
|
15
|
+
return psutil.virtual_memory().available / 1024**2
|
|
@@ -1,31 +1,32 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: careamics
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0rc8
|
|
4
4
|
Summary: Toolbox for running N2V and friends.
|
|
5
5
|
Project-URL: homepage, https://careamics.github.io/
|
|
6
6
|
Project-URL: repository, https://github.com/CAREamics/careamics
|
|
7
|
-
Author-email:
|
|
7
|
+
Author-email: Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
|
|
8
8
|
License: BSD-3-Clause
|
|
9
9
|
License-File: LICENSE
|
|
10
10
|
Classifier: Development Status :: 3 - Alpha
|
|
11
11
|
Classifier: License :: OSI Approved :: BSD License
|
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
|
13
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.9
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.10
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
16
|
Classifier: Programming Language :: Python :: 3.12
|
|
18
17
|
Classifier: Typing :: Typed
|
|
19
|
-
Requires-Python: >=3.
|
|
18
|
+
Requires-Python: >=3.9
|
|
20
19
|
Requires-Dist: bioimageio-core>=0.6.0
|
|
20
|
+
Requires-Dist: numpy<2.0.0
|
|
21
21
|
Requires-Dist: psutil
|
|
22
22
|
Requires-Dist: pydantic>=2.5
|
|
23
23
|
Requires-Dist: pytorch-lightning>=2.2.0
|
|
24
24
|
Requires-Dist: pyyaml
|
|
25
|
-
Requires-Dist: scikit-image
|
|
25
|
+
Requires-Dist: scikit-image<=0.23.2
|
|
26
26
|
Requires-Dist: tifffile
|
|
27
27
|
Requires-Dist: torch>=2.0.0
|
|
28
|
-
Requires-Dist:
|
|
28
|
+
Requires-Dist: torchvision
|
|
29
|
+
Requires-Dist: zarr<3.0.0
|
|
29
30
|
Provides-Extra: dev
|
|
30
31
|
Requires-Dist: pre-commit; extra == 'dev'
|
|
31
32
|
Requires-Dist: pytest; extra == 'dev'
|