careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pixel manipulation methods.
|
|
3
|
+
|
|
4
|
+
Pixel manipulation is used in N2V and similar algorithm to replace the value of
|
|
5
|
+
masked pixels.
|
|
6
|
+
"""
|
|
7
|
+
from typing import Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _apply_struct_mask(
    patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
) -> np.ndarray:
    """Apply structN2V masks to a patch.

    Every row of `coords` is the center of one structN2V mask. A mask is a line of
    `struct_params.span` pixels along the struct axis; the pixels of that line,
    except the center itself, are overwritten with values drawn uniformly between
    the patch minimum and maximum.

    The struct axis is counted from the end of the patch shape: `axis == 0` is the
    last axis and `axis == 1` the second-to-last one. Masked pixels that fall
    outside the patch along that axis are discarded.

    Note that the structN2V mask is applied in 2D at the coordinates given by
    `coords`, and that `patch` is modified in place (and returned).

    Parameters
    ----------
    patch : np.ndarray
        Patch to be manipulated, 2D or 3D.
    coords : np.ndarray
        Coordinates of the mask centers, one row per center.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    np.ndarray
        The same `patch` array, with the structN2V masks applied.
    """
    n_dims = patch.ndim
    struct_axis = -1 - struct_params.axis

    # 1D line of `span` pixels whose middle element is the (preserved) center
    half_span = struct_params.span // 2
    line = np.ones(struct_params.span)
    line[half_span] = 0
    offsets = np.flatnonzero(line) - half_span

    # Displacement vectors (n_offsets, n_dims), non-zero only along the struct axis
    displacements = np.zeros((offsets.size, n_dims), dtype=int)
    displacements[:, struct_axis] = offsets

    # Every displacement applied to every center, ordered displacement-major
    targets = (displacements[:, None, :] + coords[None, :, :]).reshape(-1, n_dims)

    # Keep only the targets lying inside the patch along the struct axis
    along_axis = targets[:, struct_axis]
    upper_bound = patch.shape[struct_axis] - 1
    targets = targets[(along_axis >= 0) & (along_axis <= upper_bound)]

    # Overwrite the targeted pixels with values from a flat distribution
    patch[tuple(targets.T)] = np.random.uniform(
        patch.min(), patch.max(), size=targets.shape[0]
    )

    return patch
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
    """
    Randomly round the step size of the masking grid to an integer.

    This accounts for cases where the step size is not an integer: integer steps
    are returned unchanged (as a float), non-integer steps are rounded down or up
    depending on a random draw in {0, 1}. One integer is always drawn from `rng`,
    whether or not `step` is an integer.

    Parameters
    ----------
    step : float
        Step size of the grid, output of np.linspace.
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        Rounded step size, `np.floor(step)` or `np.ceil(step)`.
    """
    lower = np.floor(step)

    # Drawn unconditionally so that the generator always advances by one draw
    coin = rng.integers(0, 2)

    if lower == step or coin == 0:
        return lower
    return np.ceil(step)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _get_stratified_coords(
    mask_pixel_perc: float,
    shape: Union[Tuple[int, int], Tuple[int, int, int]],
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Generate coordinates of the pixels to mask.

    Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
    the distance between masked pixels is approximately the same. A regular grid
    with a spacing of about ``(100 / mask_pixel_perc) ** (1 / len(shape))`` pixels
    is laid over the patch, then every grid point is shifted by a random offset
    bounded by the (randomly rounded) grid step of each axis.

    Parameters
    ----------
    mask_pixel_perc : float
        Actual (quasi) percentage of masked pixels across the whole image. Used in
        calculating the distance between masked pixels across each axis. Must be
        strictly positive.
    shape : Tuple[int, ...]
        Shape of the input patch.
    rng : np.random.Generator, optional
        Random number generator, by default a new unseeded generator. Pass a seeded
        generator to obtain reproducible coordinates.

    Returns
    -------
    np.ndarray
        Array of coordinates of the masked pixels, shape (n_points, len(shape)).

    Raises
    ------
    ValueError
        If `shape` is not 2D or 3D, or if `mask_pixel_perc` is not strictly
        positive.
    """
    if len(shape) < 2 or len(shape) > 3:
        raise ValueError(
            "Calculating coordinates is only possible for 2D and 3D patches"
        )

    if not mask_pixel_perc > 0:
        raise ValueError(
            "The percentage of masked pixels must be strictly positive, got "
            f"{mask_pixel_perc}."
        )

    if rng is None:
        rng = np.random.default_rng()

    # Distance between masked pixels along each axis. Clamped to 1 pixel, since
    # large percentages would round it to 0 and divide by zero below.
    mask_pixel_distance = max(
        1, int(np.round((100 / mask_pixel_perc) ** (1 / len(shape))))
    )

    # Regular grid of coordinates along each axis, and its (float) step size
    pixel_coords = []
    steps = []
    for axis_size in shape:
        num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
        axis_pixel_coords, step = np.linspace(
            0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
        )
        pixel_coords.append(axis_pixel_coords)
        steps.append(step)

    # All combinations of the per-axis coordinates, shape (n_points, ndim)
    coordinate_grid = np.stack(
        np.meshgrid(*pixel_coords, indexing="ij"), axis=-1
    ).reshape(-1, len(shape))

    # Random shift in [0, jitter_i - 1] along axis i, where jitter_i is the randomly
    # rounded step of that axis. Using each axis' own step (rather than the largest
    # one) keeps the points of an axis with a small step inside their grid cell
    # instead of pushing them against the patch border.
    max_increment = (
        np.array(
            [_odd_jitter_func(float(step), rng) for step in steps], dtype=np.int64
        )
        - 1
    )
    grid_random_increment = rng.integers(
        0, max_increment, size=coordinate_grid.shape, endpoint=True
    )

    coordinate_grid = coordinate_grid + grid_random_increment
    return np.clip(coordinate_grid, 0, np.array(shape) - 1)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _create_subpatch_center_mask(
    subpatch: np.ndarray, center_coords: np.ndarray
) -> np.ndarray:
    """Create a boolean mask selecting every subpatch pixel except the center.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated; only its shape is used.
    center_coords : np.ndarray
        Coordinates of the center within the subpatch (the original center
        shifted by the crop offset).

    Returns
    -------
    np.ndarray
        Boolean mask with the subpatch shape, False at `center_coords` and True
        elsewhere. As built by `np.ma.make_mask`, an all-False result (single-pixel
        subpatch) is shrunk to `np.ma.nomask`.
    """
    keep = np.full(subpatch.shape, 1.0)
    keep[tuple(center_coords)] = 0.0
    return np.ma.make_mask(keep)  # type: ignore
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _create_subpatch_struct_mask(
    subpatch: np.ndarray,
    center_coords: np.ndarray,
    struct_params: "StructMaskParameters",
) -> np.ndarray:
    """Create a structN2V mask for the subpatch.

    The mask is True for the pixels that may be used to compute the replacement
    value and False for the excluded ones: the center pixel and its neighbours
    along the struct axis, up to ``(span - 1) // 2`` pixels on each side (clipped
    to the subpatch).

    The struct axis is counted from the end of the subpatch shape, with the same
    convention as `_apply_struct_mask` (``-1 - axis``): ``axis == 0`` is the last
    (X) axis and ``axis == 1`` the second-to-last (Y) axis.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated; only its shape is used.
    center_coords : np.ndarray
        Coordinates of the center within the subpatch (the original center shifted
        by the crop offset).
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    np.ndarray
        Boolean mask with the shape of `subpatch`.
    """
    # Struct axis as a positive axis index (last axis for 0, second-to-last for 1)
    struct_axis = subpatch.ndim - 1 - struct_params.axis
    half_span = (struct_params.span - 1) // 2

    # Index of the center pixel, widened to a segment along the struct axis. The
    # previous implementation indexed row `axis` after moving the struct axis to
    # the front, so it never masked the line through the center (nor the center).
    index = tuple(
        slice(max(0, int(coord) - half_span), min(int(coord) + half_span + 1, size))
        if axis == struct_axis
        else int(coord)
        for axis, (coord, size) in enumerate(zip(center_coords, subpatch.shape))
    )

    mask = np.ones(subpatch.shape, dtype=bool)
    mask[index] = False
    return mask
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def uniform_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: Optional[StructMaskParameters] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with a neighbor values.

    Manipulated pixels are selected on a randomly jittered grid, i.e. with an
    approximately uniform probability across the whole patch. Each of them is
    replaced by the value of a pixel drawn in the subpatch of size `subpatch_size`
    around it, the drawn coordinates being clipped to the patch.

    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to exclude the zero offset when drawing the replacement pixel, by
        default True. The offsets of every axis are drawn independently, so with
        True the offset is non-zero along every axis. See uniform with/without
        central pixel in the documentation. #TODO add link
    struct_params: Optional[StructMaskParameters]
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The manipulated patch and the mask (uint8, same shape as `patch`). The mask
        is 1 at the manipulated pixels whose value was taken from a different pixel,
        and 0 elsewhere.
    """
    transformed_patch = patch.copy()

    # Coordinates of the pixels to be replaced, one row per pixel
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
    rng = np.random.default_rng()

    # Offsets available within the subpatch, along a single axis
    roi_span_full = np.arange(
        -np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)
    ).astype(np.int32)

    # Remove the zero offset if needed
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full

    # Randomly select one offset per axis for every pixel
    random_increment = rng.choice(roi_span, size=subpatch_centers.shape)

    # Clip the replacement coordinates to the patch
    replacement_coords = np.clip(
        subpatch_centers + random_increment, 0, np.array(patch.shape) - 1
    )

    # Replace the original pixels with the replacement pixels
    center_index = tuple(subpatch_centers.T)
    transformed_patch[center_index] = patch[tuple(replacement_coords.T)]

    # Mark the pixels whose value was taken from a *different* pixel. Comparing
    # coordinates rather than values keeps pixels whose replacement happens to have
    # the same value (e.g. flat regions), while pixels replaced by themselves (zero
    # offset, or offsets clipped back onto the center at the border) stay unmasked.
    moved = np.any(replacement_coords != subpatch_centers, axis=1)
    mask = np.zeros(patch.shape, dtype=np.uint8)
    mask[center_index] = moved

    if struct_params is not None:
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params
        )

    return (
        transformed_patch,
        mask,
    )
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def median_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: Optional[StructMaskParameters] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version, manipulated pixels are selected randomly away from a grid with an
    approximate uniform probability to be selected across the whole patch. The
    median excludes the manipulated pixel itself (or, with `struct_params`, the
    pixels selected by the structN2V subpatch mask).

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
        Must be at least 1.
    struct_params: Optional[StructMaskParameters]
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The manipulated patch and the mask (uint8, same shape as `patch`), which is
        1 at every manipulated pixel and 0 elsewhere.

    Raises
    ------
    ValueError
        If `subpatch_size` is smaller than 1.
    """
    if subpatch_size < 1:
        raise ValueError(f"Subpatch size must be at least 1, got {subpatch_size}.")

    transformed_patch = patch.copy()

    # Coordinates of the pixels to be replaced, one row per pixel
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)

    # (start, stop) offsets of the subpatch around its center
    roi_span = np.array(
        [-np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)]
    ).astype(np.int32)

    # (start, stop) of each subpatch along each axis, clipped to the patch,
    # shape (n_centers, ndim, 2). Since every center lies inside the patch and
    # subpatch_size >= 1, each span contains at least the center.
    subpatch_bounds = np.clip(
        subpatch_centers[..., np.newaxis] + roi_span,
        0,
        np.array(patch.shape)[:, np.newaxis],
    )

    for center, bounds in zip(subpatch_centers, subpatch_bounds):
        subpatch = patch[tuple(slice(start, stop) for start, stop in bounds)]
        center_in_subpatch = center - bounds[:, 0]

        if struct_params is None:
            subpatch_mask = _create_subpatch_center_mask(subpatch, center_in_subpatch)
        else:
            subpatch_mask = _create_subpatch_struct_mask(
                subpatch, center_in_subpatch, struct_params
            )
        transformed_patch[tuple(center)] = np.median(subpatch[subpatch_mask])

    # Every manipulated pixel belongs to the mask, even when the median happens to
    # equal its original value (comparing values would silently drop it).
    mask = np.zeros(patch.shape, dtype=np.uint8)
    mask[tuple(subpatch_centers.T)] = 1

    if struct_params is not None:
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params
        )

    return (
        transformed_patch,
        mask,
    )
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass
class StructMaskParameters:
    """Parameters of structN2V masks.

    Parameters
    ----------
    axis : Literal[0, 1]
        Axis along which to apply the mask, horizontal (0) or vertical (1).
    span : int
        Span of the mask, i.e. its length in pixels along `axis`. Must be at
        least 1.

    Raises
    ------
    ValueError
        If `axis` is not 0 or 1, or if `span` is smaller than 1.
    """

    axis: Literal[0, 1]
    span: int

    def __post_init__(self) -> None:
        """Validate the parameters, since dataclasses do not enforce type hints."""
        if self.axis not in (0, 1):
            raise ValueError(
                "Struct mask axis must be 0 (horizontal) or 1 (vertical), got "
                f"{self.axis}."
            )
        if self.span < 1:
            raise ValueError(f"Struct mask span must be at least 1, got {self.span}.")
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Test-time augmentations."""
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from torch import Tensor, flip, mean, rot90, stack
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# TODO add tests
|
|
9
|
+
class ImageRestorationTTA:
    """
    Test-time augmentation for image restoration tasks.

    The augmentation is performed using all 90 deg rotations in the YX plane and
    their versions flipped along X (the 8 symmetries of the square).

    Tensors should be of shape SC(Z)YX. Only the Y and X axes are transformed, so
    channels are never reordered.

    This transformation is used in the LightningModule in order to perform test-time
    augmentation.
    """

    def __init__(self) -> None:
        """Constructor."""
        pass

    def forward(self, x: Tensor) -> List[Tensor]:
        """
        Apply test-time augmentation to the input tensor.

        Parameters
        ----------
        x : Tensor
            Input tensor, shape SC(Z)YX.

        Returns
        -------
        List[Tensor]
            The 4 rotations of `x` (0, 90, 180 and 270 deg in YX), followed by the
            same 4 rotations flipped along X.
        """
        augmented = [x] + [rot90(x, k, dims=(-2, -1)) for k in (1, 2, 3)]

        # Flip along X only: for SCYX tensors dim -3 is the channel axis, and
        # flipping it would feed the channels to the model in reversed order.
        return augmented + [flip(x_, dims=(-1,)) for x_ in augmented]

    def backward(self, x: List[Tensor]) -> Tensor:
        """Undo the test-time augmentation and average the results.

        Parameters
        ----------
        x : List[Tensor]
            List of 8 augmented tensors, in the order produced by `forward`.

        Returns
        -------
        Tensor
            Mean of the de-augmented tensors.
        """
        reverse = [x[0]] + [rot90(x[k], -k, dims=(-2, -1)) for k in (1, 2, 3)]
        reverse += [
            rot90(flip(x[4 + k], dims=(-1,)), -k, dims=(-2, -1)) for k in range(4)
        ]
        return mean(stack(reverse), dim=0)
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from typing import Any, Dict, Tuple
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from albumentations import DualTransform
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class XYRandomRotate90(DualTransform):
    """Applies random 90 degree rotations to the YX axis.

    This transform expects (Z)YXC dimensions.

    Parameters
    ----------
    p : float, optional
        Probability to apply the transform, by default 0.5
    is_3D : bool, optional
        Whether the patches are 3D, by default False
    """

    def __init__(self, p: float = 0.5, is_3D: bool = False):
        """Constructor.

        Parameters
        ----------
        p : float, optional
            Probability to apply the transform, by default 0.5
        is_3D : bool, optional
            Whether the patches are 3D, by default False
        """
        super().__init__(p=p)

        self.is_3D = is_3D

        # rotation axes (Y, X), after the leading Z axis for 3D patches
        self.axes = (1, 2) if is_3D else (0, 1)

    def get_params(self, **kwargs: Any) -> Dict[str, int]:
        """Get the transform parameters.

        Returns
        -------
        Dict[str, int]
            Transform parameters, the number of 90 degree rotations (1, 2 or 3).
        """
        return {"n_rotations": np.random.randint(1, 4)}

    def _validate_ndim(self, array: np.ndarray, name: str) -> None:
        """Raise a ValueError if a 3D transform receives an array that is not 4D."""
        if self.is_3D and array.ndim != 4:
            raise ValueError(
                f"Incompatible {name} shape and dimensionality. ZYXC patch shape "
                f"expected, but got shape {array.shape}."
            )

    def apply(self, patch: np.ndarray, n_rotations: int, **kwargs: Any) -> np.ndarray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        n_rotations : int
            Number of 90 degree rotations to apply in the YX plane.

        Returns
        -------
        np.ndarray
            Rotated patch, as a contiguous array.

        Raises
        ------
        ValueError
            If the transform is 3D and the patch is not 4D (ZYXC).
        """
        self._validate_ndim(patch, "patch")

        return np.ascontiguousarray(np.rot90(patch, k=n_rotations, axes=self.axes))

    def apply_to_mask(
        self, mask: np.ndarray, n_rotations: int, **kwargs: Any
    ) -> np.ndarray:
        """Apply the transform to the mask.

        Parameters
        ----------
        mask : np.ndarray
            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
        n_rotations : int
            Number of 90 degree rotations to apply in the YX plane.

        Returns
        -------
        np.ndarray
            Rotated mask, as a contiguous array.

        Raises
        ------
        ValueError
            If the transform is 3D and the mask is not 4D (ZYXC).
        """
        self._validate_ndim(mask, "mask")

        return np.ascontiguousarray(np.rot90(mask, k=n_rotations, axes=self.axes))

    def get_transform_init_args_names(self) -> Tuple[str, str]:
        """
        Get the transform arguments.

        Returns
        -------
        Tuple[str, str]
            Transform arguments.
        """
        return ("p", "is_3D")
|
careamics/utils/__init__.py
CHANGED
|
@@ -2,19 +2,17 @@
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
__all__ = [
|
|
5
|
-
"denormalize",
|
|
6
|
-
"normalize",
|
|
7
|
-
"get_device",
|
|
8
|
-
"check_axes_validity",
|
|
9
|
-
"add_axes",
|
|
10
|
-
"check_tiling_validity",
|
|
11
5
|
"cwd",
|
|
12
|
-
"
|
|
6
|
+
"get_ram_size",
|
|
7
|
+
"check_path_exists",
|
|
8
|
+
"BaseEnum",
|
|
9
|
+
"get_logger",
|
|
10
|
+
"get_careamics_home",
|
|
13
11
|
]
|
|
14
12
|
|
|
15
13
|
|
|
16
|
-
from .
|
|
17
|
-
from .
|
|
18
|
-
from .
|
|
19
|
-
from .
|
|
20
|
-
from .
|
|
14
|
+
from .base_enum import BaseEnum
|
|
15
|
+
from .context import cwd, get_careamics_home
|
|
16
|
+
from .logging import get_logger
|
|
17
|
+
from .path_utils import check_path_exists
|
|
18
|
+
from .ram import get_ram_size
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from enum import Enum, EnumMeta
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class _ContainerEnum(EnumMeta):
    """Enum metaclass allowing membership tests with raw values."""

    def __contains__(cls, item: Any) -> bool:
        """Return whether `item` is a member, or the value of a member, of the enum.

        Parameters
        ----------
        item : Any
            Value (or member) to look up.

        Returns
        -------
        bool
            True if the enum can be constructed from `item`, False otherwise.
        """
        try:
            cls(item)
        except ValueError:
            return False
        return True

    def has_value(cls, value: Any) -> bool:
        """Return whether `value` is the value of one of the enum members.

        Defined as a plain metaclass method (not a classmethod) so that `cls` is the
        enum class itself, e.g. ``MyEnum.has_value("value")``; a classmethod on the
        metaclass would be bound to the metaclass, which has no member map.

        Parameters
        ----------
        value : Any
            Value to look up.

        Returns
        -------
        bool
            True if `value` is a member value, False otherwise (including for
            unhashable values).
        """
        try:
            return value in cls._value2member_map_
        except TypeError:
            # unhashable values cannot be keys of the value-to-member map
            return False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseEnum(Enum, metaclass=_ContainerEnum):
    """Base Enum class supporting membership tests with raw values.

    Subclasses can be queried with the `in` operator using a plain value, without
    looking up the corresponding member first.

    Example
    -------
    >>> from careamics.utils.base_enum import BaseEnum
    >>> # Define a new enum
    >>> class BaseEnumExtension(BaseEnum):
    ...     VALUE = "value"
    >>> # Check if value is in the enum
    >>> "value" in BaseEnumExtension
    True
    """
|
careamics/utils/context.py
CHANGED
|
@@ -9,6 +9,24 @@ from pathlib import Path
|
|
|
9
9
|
from typing import Iterator, Union
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
def get_careamics_home() -> Path:
    """Return the CAREamics home directory, creating it if nothing is there yet.

    The CAREamics home directory is ``.careamics``, a hidden folder in the user's
    home directory.

    Returns
    -------
    Path
        CAREamics home directory path.
    """
    careamics_dir = Path.home().joinpath(".careamics")

    if careamics_dir.exists():
        return careamics_dir

    careamics_dir.mkdir(parents=True, exist_ok=True)
    return careamics_dir
|
|
28
|
+
|
|
29
|
+
|
|
12
30
|
@contextmanager
|
|
13
31
|
def cwd(path: Union[str, Path]) -> Iterator[None]:
|
|
14
32
|
"""
|
|
@@ -29,8 +47,10 @@ def cwd(path: Union[str, Path]) -> Iterator[None]:
|
|
|
29
47
|
|
|
30
48
|
Examples
|
|
31
49
|
--------
|
|
32
|
-
|
|
33
|
-
|
|
50
|
+
The context is whcnaged within the block and then restored to the original one.
|
|
51
|
+
|
|
52
|
+
>>> with cwd(my_path):
|
|
53
|
+
... pass # do something
|
|
34
54
|
"""
|
|
35
55
|
path = Path(path)
|
|
36
56
|
|