careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pixel manipulation methods.
|
|
3
|
+
|
|
4
|
+
Pixel manipulation is used in N2V and similar algorithms to replace the value of
|
|
5
|
+
masked pixels.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional, Tuple
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _apply_struct_mask(
    patch: np.ndarray,
    coords: np.ndarray,
    struct_params: StructMaskParameters,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Apply structN2V masks to patch.

    Each row of `coords` is the center of a line-shaped mask parameterized by
    `struct_params`. The pixels covered by each line, excluding the center pixel
    itself, are replaced by values drawn uniformly from
    ``[patch.min(), patch.max())``.

    The line extends along the last axis when ``struct_params.axis == 0`` and
    along the second-to-last axis when ``struct_params.axis == 1``. For 3D patches
    it therefore always lies within the last two (YX) axes.

    Parameters
    ----------
    patch : np.ndarray
        Patch to be manipulated, 2D or 3D. It is modified in place.
    coords : np.ndarray
        Coordinates of the mask centers, one center per row, i.e. shape
        (n_centers, patch.ndim).
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None
        Random number generator. If None, a new unseeded generator is created.

    Returns
    -------
    np.ndarray
        The same array object as `patch`, with the structN2V mask applied.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Negative index of the axis the mask line extends along:
    # axis 0 -> -1 (last axis), axis 1 -> -2 (second-to-last axis)
    moving_axis = -1 - struct_params.axis

    # Line of ones of length `span`, with leading singleton dims so that it has
    # as many dimensions as the patch
    mask = np.expand_dims(
        np.ones(struct_params.span), axis=list(range(len(patch.shape) - 1))
    )  # (1, 1, span) or (1, span)

    # Move the line onto the moving axis,
    # i.e. the axis along which the coordinates should change
    mask = np.moveaxis(mask, -1, moving_axis)
    center = np.array(mask.shape) // 2

    # Zero the center so that it is excluded from the displacements below
    mask[tuple(center.T)] = 0

    # Displacements from the center to every other pixel of the line,
    # shape (ndim, n_displacements); non-zero only along the moving axis
    dx = np.indices(mask.shape)[:, mask == 1] - center[:, None]

    # Add every displacement to every center. After the transpose/reshape, rows
    # are ordered displacement-major: all centers for the first displacement,
    # then all centers for the next one, etc. Shape (n_disp * n_centers, ndim).
    mix = dx.T[..., None] + coords.T[None]
    mix = mix.transpose([1, 0, 2]).reshape([mask.ndim, -1]).T

    # Delete entries that are out of bounds; only the moving axis can leave the
    # patch since the displacements are zero along every other axis
    mix = np.delete(mix, mix[:, moving_axis] < 0, axis=0)

    max_bound = patch.shape[moving_axis] - 1
    mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0)

    # Replace the masked neighbouring pixels (in place) with random values from a
    # flat distribution over the patch's value range
    patch[tuple(mix.T)] = rng.uniform(patch.min(), patch.max(), size=mix.shape[0])

    return patch
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
    """
    Round a possibly fractional grid step to an integer value, at random.

    An integral `step` is returned unchanged (as a NumPy float). A fractional
    `step` is rounded down or up, each with probability 1/2. This accounts for
    grids whose step size is not an integer.

    Exactly one value is drawn from `rng` on every call, whether or not `step`
    is integral.

    Parameters
    ----------
    step : float
        Step size of the grid, output of np.linspace.
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        `np.floor(step)` or `np.ceil(step)` as a NumPy floating-point scalar.
    """
    # Draw unconditionally so that the random stream consumption does not
    # depend on `step`.
    coin_flip = rng.integers(0, 2)

    is_fractional = np.floor(step) != step
    if is_fractional and coin_flip == 1:
        return np.ceil(step)
    return np.floor(step)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _get_stratified_coords(
    mask_pixel_perc: float,
    shape: Tuple[int, ...],
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Generate coordinates of the pixels to mask.

    A regular grid is laid over the patch. Its spacing is chosen so that roughly
    `mask_pixel_perc` percent of the pixels are grid points. Every grid point is
    then shifted by a random non-negative offset along each axis and clipped to
    the patch. The resulting masked pixels are spread approximately evenly.

    Parameters
    ----------
    mask_pixel_perc : float
        Approximate percentage of masked pixels across the whole patch. It is
        used to derive the spacing between masked pixels along each axis.
    shape : Tuple[int, ...]
        Shape of the input patch, 2D or 3D.
    rng : np.random.Generator or None
        Random number generator. If None, a new unseeded generator is created.

    Returns
    -------
    np.ndarray
        Integer array with one coordinate per row, shape (n_points, len(shape)).

    Raises
    ------
    ValueError
        If `shape` is neither 2D nor 3D.
    """
    n_dims = len(shape)
    if n_dims not in (2, 3):
        raise ValueError(
            "Calculating coordinates is only possible for 2D and 3D patches"
        )

    if rng is None:
        rng = np.random.default_rng()

    # Target spacing between masked pixels, the same along every axis
    spacing = np.round((100 / mask_pixel_perc) ** (1 / n_dims)).astype(np.int32)

    axis_grids = []
    axis_steps = []
    for size in shape:
        # Number of grid points needed to cover the axis at `spacing`
        n_points = int(np.ceil(size / spacing))
        grid, step = np.linspace(
            0, size, n_points, dtype=np.int32, endpoint=False, retstep=True
        )
        axis_grids.append(grid)
        axis_steps.append(step)

    # Cartesian product of the per-axis grids, one grid point per row
    grid_points = np.array(np.meshgrid(*axis_grids)).reshape(n_dims, -1).T

    # Random offsets in [0, jittered_step - 1], with the same range on every axis
    jittered_step = _odd_jitter_func(float(max(axis_steps)), rng)
    offsets = rng.integers(
        jittered_step * np.ones_like(grid_points).astype(np.int32) - 1,  # type: ignore
        size=grid_points.shape,
        endpoint=True,
    )
    grid_points += offsets

    return np.clip(grid_points, 0, np.array(shape) - 1)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _create_subpatch_center_mask(
    subpatch: np.ndarray, center_coords: np.ndarray
) -> np.ndarray:
    """Build a boolean mask that selects every subpatch pixel except the center.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated; only its shape is used.
    center_coords : np.ndarray
        Coordinates of the center pixel relative to `subpatch`, i.e. after the
        possible crop at the patch borders.

    Returns
    -------
    np.ndarray
        Boolean mask with the shape of `subpatch`: False at `center_coords` and
        True elsewhere. As per `np.ma.make_mask`, `np.ma.nomask` is returned
        instead if every entry would be False (single-pixel subpatch).
    """
    keep = np.full(subpatch.shape, 1.0)
    keep[tuple(center_coords)] = 0.0

    # Convert the 0/1 float array to a boolean mask
    return np.ma.make_mask(keep)  # type: ignore
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _create_subpatch_struct_mask(
    subpatch: np.ndarray, center_coords: np.ndarray, struct_params: StructMaskParameters
) -> np.ndarray:
    """Create a structN2V mask for the subpatch.

    The returned mask is True for pixels that may be used, e.g. to compute the
    median replacement value. It is False on a line of pixels through
    `center_coords`, including the center itself. The line extends
    ``(struct_params.span - 1) // 2`` pixels on each side of the center and is
    clipped to the subpatch bounds.

    The axis convention matches `_apply_struct_mask`:

    - ``axis == 0`` masks along the last axis (horizontal).
    - ``axis == 1`` masks along the second-to-last axis (vertical).

    For 3D subpatches the line therefore always lies within the last two axes.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated; only its shape is used.
    center_coords : np.ndarray
        Coordinates of the center pixel relative to `subpatch`, i.e. after the
        possible crop at the patch borders.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    np.ndarray
        Boolean mask with the same shape as `subpatch`.
    """
    # Absolute index of the axis the mask line extends along. This is the same
    # axis as `-1 - struct_params.axis`, the convention used by `_apply_struct_mask`.
    struct_axis = subpatch.ndim - 1 - struct_params.axis
    half_span = (struct_params.span - 1) // 2

    # Index of the line through the center: the center coordinate on every axis,
    # except the struct axis, where a window clipped to the subpatch is used.
    line_index: list = [int(coord) for coord in center_coords]
    center = line_index[struct_axis]
    line_index[struct_axis] = slice(
        max(0, center - half_span),
        min(center + half_span + 1, subpatch.shape[struct_axis]),
    )

    mask = np.ones(subpatch.shape, dtype=bool)
    mask[tuple(line_index)] = False
    return mask
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def uniform_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: Optional[StructMaskParameters] = None,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with a neighbor value.

    The pixels to manipulate come from `_get_stratified_coords`, a randomly
    jittered grid giving approximately uniform coverage of the whole patch.
    Each of them is replaced by the value of a pixel at a random offset, drawn
    independently per axis, within the surrounding subpatch of size
    `subpatch_size`.

    If `struct_params` is not None, an additional structN2V mask is applied to
    the data, replacing the pixels in the mask with random values (excluding the
    pixel already manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x). It is not modified.
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to exclude the zero offset (the pixel itself) from the candidate
        replacement offsets, by default True.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span), by default None.
    rng : np.random.Generator or None
        Random number generator, by default None (a new unseeded generator). The
        same generator is also used for the structN2V mask, so a seeded generator
        gives fully reproducible results.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The manipulated patch, and a uint8 mask. The mask is 1 where the value
        changed due to the pixel replacement (before the structN2V mask is
        applied) and 0 elsewhere.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Coordinates of the pixels to be replaced, one per row
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)

    # Candidate offsets along each axis, e.g. -5..5 for a subpatch of size 11
    roi_span_full = np.arange(
        -np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)
    ).astype(np.int32)

    # Remove the zero offset (the pixel itself) if needed
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full

    # Draw an independent offset for every center and every axis
    random_increment = rng.choice(roi_span, size=subpatch_centers.shape)

    # Keep the replacement coordinates inside the patch. Near the borders the
    # clip can map a replacement back onto the center pixel itself.
    replacement_coords = np.clip(
        subpatch_centers + random_increment, 0, np.array(patch.shape) - 1
    )

    # Gather the replacement values and write them at the centers
    replacement_pixels = patch[tuple(replacement_coords.T.tolist())]
    transformed_patch[tuple(subpatch_centers.T.tolist())] = replacement_pixels

    # NOTE: the mask is value-based. A center whose replacement value equals its
    # original value (including the clipped self-replacement above) is not
    # flagged.
    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        # Forward `rng` so that seeded runs stay reproducible with structN2V
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return transformed_patch, mask
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def median_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: Optional[StructMaskParameters] = None,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    This is the N2V2 version: manipulated pixels are selected randomly away from a
    grid, with an approximately uniform probability of being selected across the
    whole patch.

    The median is taken over the subpatch, clipped at the patch borders. Pixels are
    selected with the mask from `_create_subpatch_center_mask`, or from
    `_create_subpatch_struct_mask` when `struct_params` is given.

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x). It is not modified.
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed over, by default 11.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span), by default None.
    rng : np.random.Generator or None, optional
        Random number generator, by default None (a new unseeded generator). The
        same generator is also used for the structN2V mask, so a seeded generator
        gives fully reproducible results.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The manipulated patch, and a uint8 mask. The mask is 1 where the value
        changed due to the median replacement (before the structN2V mask is
        applied) and 0 elsewhere.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Coordinates of the pixels to be replaced, one per row
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)

    # Half-open window [-floor(s/2), ceil(s/2)) around each center
    roi_span = np.array(
        [-np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)]
    ).astype(np.int32)

    # Shape (n_dims, n_centers, 2): (start, stop) of the window per axis and center
    subpatch_crops_span_full = subpatch_centers[np.newaxis, ...].T + roi_span

    # Clip the windows to the patch bounds
    subpatch_crops_span_clipped = np.clip(
        subpatch_crops_span_full,
        a_min=np.zeros_like(patch.shape)[:, np.newaxis, np.newaxis],
        a_max=np.array(patch.shape)[:, np.newaxis, np.newaxis],
    )

    for idx, center in enumerate(subpatch_centers):
        subpatch_coords = subpatch_crops_span_clipped[:, idx, ...]
        idxs = [
            slice(x[0], x[1]) if x[1] - x[0] > 0 else slice(0, 1)
            for x in subpatch_coords
        ]
        subpatch = patch[tuple(idxs)]

        # Center position relative to the (possibly cropped) subpatch
        subpatch_center_adjusted = center - subpatch_coords[:, 0]

        if struct_params is None:
            subpatch_mask = _create_subpatch_center_mask(
                subpatch, subpatch_center_adjusted
            )
        else:
            subpatch_mask = _create_subpatch_struct_mask(
                subpatch, subpatch_center_adjusted, struct_params
            )
        transformed_patch[tuple(center)] = np.median(subpatch[subpatch_mask])

    # NOTE: value-based mask; a center whose median equals its original value is
    # not flagged.
    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        # Forward `rng` so that seeded runs stay reproducible with structN2V
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return transformed_patch, mask
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Class representing the parameters of structN2V masks."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass()
class StructMaskParameters:
    """Container for the parameters of structN2V masks.

    Attributes
    ----------
    axis : Literal[0, 1]
        Orientation of the mask line: horizontal (0) or vertical (1).
    span : int
        Span of the mask, i.e. the length of the masked line.
    """

    # orientation of the mask line: 0 horizontal, 1 vertical
    axis: Literal[0, 1]
    # length of the mask line
    span: int
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""A general parent class for transforms."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Transform:
    """Base class shared by the transforms.

    Subclasses override ``__call__``; the base implementation does nothing and
    returns None.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Apply the transform.

        Parameters
        ----------
        *args : Any
            Arguments.
        **kwargs : Any
            Keyword arguments.

        Returns
        -------
        Any
            Transformed data; always None for this base implementation.
        """
        return None
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Test-time augmentations."""
|
|
2
|
+
|
|
3
|
+
from torch import Tensor, flip, mean, rot90, stack
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ImageRestorationTTA:
    """
    Test-time augmentation for image restoration tasks.

    The eight augmentations are:

    - the identity;
    - the three 90 degree rotations;
    - the original image flipped along Y and along X;
    - the once-rotated image flipped along Y and along X.

    Together they cover all eight rotations/reflections of the YX plane.

    Tensors should be of shape SC(Z)YX, and only the last two (YX) axes are
    transformed.

    This transformation is used in the LightningModule in order to perform test-time
    augmentation.
    """

    def forward(self, input_tensor: Tensor) -> list[Tensor]:
        """
        Apply test-time augmentation to the input tensor.

        Parameters
        ----------
        input_tensor : Tensor
            Input tensor, shape SC(Z)YX.

        Returns
        -------
        list of torch.Tensor
            The eight augmented tensors, in the order expected by `backward`.
        """
        yx = (-2, -1)

        rotations = [rot90(input_tensor, k, dims=yx) for k in (1, 2, 3)]
        flips = [flip(input_tensor, dims=(axis,)) for axis in yx]
        rotated_flips = [flip(rotations[0], dims=(axis,)) for axis in yx]

        return [input_tensor, *rotations, *flips, *rotated_flips]

    def backward(self, x: list[Tensor]) -> Tensor:
        """Undo the test-time augmentation and average the results.

        Parameters
        ----------
        x : list of torch.Tensor
            The eight augmented tensors of shape SC(Z)YX, in the order produced
            by `forward`.

        Returns
        -------
        torch.Tensor
            Mean of the de-augmented tensors, in the original orientation.
        """
        yx = (-2, -1)

        restored = [x[0]]
        # undo the rotations
        restored.extend(rot90(x[k], -k, dims=yx) for k in (1, 2, 3))
        # undo the flips of the original image
        restored.extend(flip(x[4 + i], dims=(axis,)) for i, axis in enumerate(yx))
        # undo flip, then the single rotation
        restored.extend(
            rot90(flip(x[6 + i], dims=(axis,)), -1, dims=yx)
            for i, axis in enumerate(yx)
        )

        return mean(stack(restored), dim=0)
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""XY flip transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.transforms.transform import Transform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYFlip(Transform):
    """Flip image along X and Y axis, one at a time.

    This transform randomly flips one of the last two axes.

    This transform expects C(Z)YX dimensions.

    Attributes
    ----------
    axis_indices : list of int
        Indices of the axes that can be flipped (-2 for Y, -1 for X).
    rng : np.random.Generator
        Random number generator.
    p : float
        Probability of applying the transform.

    Parameters
    ----------
    flip_x : bool, optional
        Whether to flip along the X axis, by default True.
    flip_y : bool, optional
        Whether to flip along the Y axis, by default True.
    p : float, optional
        Probability of applying the transform, by default 0.5.
    seed : Optional[int], optional
        Random seed used to initialize `rng`, by default None.
    """

    def __init__(
        self,
        flip_x: bool = True,
        flip_y: bool = True,
        p: float = 0.5,
        seed: Optional[int] = None,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        flip_x : bool, optional
            Whether to flip along the X axis, by default True.
        flip_y : bool, optional
            Whether to flip along the Y axis, by default True.
        p : float
            Probability of applying the transform, by default 0.5.
        seed : Optional[int], optional
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `p` is outside [0, 1], or if both `flip_x` and `flip_y` are False.
        """
        if p < 0 or p > 1:
            raise ValueError("Probability must be in [0, 1].")

        if not flip_x and not flip_y:
            raise ValueError("At least one axis must be flippable.")

        # probability to apply the transform
        self.p = p

        # "flippable" axes, as negative indices from the end (Y: -2, X: -1)
        self.axis_indices: list[int] = []

        if flip_y:
            self.axis_indices.append(-2)
        if flip_x:
            self.axis_indices.append(-1)

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: np.ndarray, target: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Apply the transform to the source patch and the target (optional).

        The same randomly chosen axis is flipped in both `patch` and `target`.

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : Optional[np.ndarray], optional
            Target for the patch, by default None.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray]]
            Transformed patch and target. If the transform is not applied, the
            inputs are returned unchanged (same objects).
        """
        # `random()` is in [0, 1), so flipping when it is < p gives probability
        # exactly p: p=0 never flips and p=1 always flips.
        if self.rng.random() >= self.p:
            return patch, target

        # choose an axis to flip
        axis = self.rng.choice(self.axis_indices)

        patch_transformed = self._apply(patch, axis)
        target_transformed = self._apply(target, axis) if target is not None else None

        return patch_transformed, target_transformed

    def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        axis : int
            Axis to flip.

        Returns
        -------
        np.ndarray
            Flipped image patch, as a C-contiguous array.
        """
        # np.flip returns a negatively-strided view; ascontiguousarray
        # materializes it as a C-contiguous copy.
        return np.ascontiguousarray(np.flip(patch, axis=axis))
|