careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +80 -44
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -2
- careamics/config/configuration_factory.py +4 -16
- careamics/config/data_model.py +10 -14
- careamics/config/inference_model.py +0 -65
- careamics/config/optimizer_models.py +4 -4
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/conftest.py +12 -0
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +71 -32
- careamics/dataset/iterable_dataset.py +155 -68
- careamics/dataset/patching/patching.py +56 -15
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/tiled_patching.py +3 -1
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +45 -19
- careamics/lightning_module.py +8 -2
- careamics/lightning_prediction_datamodule.py +3 -13
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/bmz_io.py +3 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction/stitch_prediction.py +2 -6
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +49 -13
- careamics/transforms/normalize.py +55 -3
- careamics/transforms/pixel_manipulation.py +5 -5
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -67
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py
CHANGED
|
@@ -34,7 +34,9 @@ class UnetEncoder(nn.Module):
|
|
|
34
34
|
Dropout probability, by default 0.0.
|
|
35
35
|
pool_kernel : int, optional
|
|
36
36
|
Kernel size for the max pooling layers, by default 2.
|
|
37
|
-
|
|
37
|
+
n2v2 : bool, optional
|
|
38
|
+
Whether to use N2V2 architecture, by default False.
|
|
39
|
+
groups : int, optional
|
|
38
40
|
Number of blocked connections from input channels to output
|
|
39
41
|
channels, by default 1.
|
|
40
42
|
"""
|
|
@@ -70,7 +72,9 @@ class UnetEncoder(nn.Module):
|
|
|
70
72
|
Dropout probability, by default 0.0.
|
|
71
73
|
pool_kernel : int, optional
|
|
72
74
|
Kernel size for the max pooling layers, by default 2.
|
|
73
|
-
|
|
75
|
+
n2v2 : bool, optional
|
|
76
|
+
Whether to use N2V2 architecture, by default False.
|
|
77
|
+
groups : int, optional
|
|
74
78
|
Number of blocked connections from input channels to output
|
|
75
79
|
channels, by default 1.
|
|
76
80
|
"""
|
|
@@ -140,7 +144,9 @@ class UnetDecoder(nn.Module):
|
|
|
140
144
|
Whether to use batch normalization, by default True.
|
|
141
145
|
dropout : float, optional
|
|
142
146
|
Dropout probability, by default 0.0.
|
|
143
|
-
|
|
147
|
+
n2v2 : bool, optional
|
|
148
|
+
Whether to use N2V2 architecture, by default False.
|
|
149
|
+
groups : int, optional
|
|
144
150
|
Number of blocked connections from input channels to output
|
|
145
151
|
channels, by default 1.
|
|
146
152
|
"""
|
|
@@ -170,7 +176,9 @@ class UnetDecoder(nn.Module):
|
|
|
170
176
|
Whether to use batch normalization, by default True.
|
|
171
177
|
dropout : float, optional
|
|
172
178
|
Dropout probability, by default 0.0.
|
|
173
|
-
|
|
179
|
+
n2v2 : bool, optional
|
|
180
|
+
Whether to use N2V2 architecture, by default False.
|
|
181
|
+
groups : int, optional
|
|
174
182
|
Number of blocked connections from input channels to output
|
|
175
183
|
channels, by default 1.
|
|
176
184
|
"""
|
|
@@ -250,22 +258,25 @@ class UnetDecoder(nn.Module):
|
|
|
250
258
|
|
|
251
259
|
@staticmethod
|
|
252
260
|
def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
|
|
253
|
-
"""
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
A.
|
|
261
|
+
"""Interleave two tensors.
|
|
262
|
+
|
|
263
|
+
Splits the tensors `A` and `B` into equally sized groups along the channel
|
|
264
|
+
axis (axis=1); then concatenates the groups in alternating order along the
|
|
265
|
+
channel axis, starting with the first group from tensor A.
|
|
258
266
|
|
|
259
267
|
Parameters
|
|
260
268
|
----------
|
|
261
|
-
A: torch.Tensor
|
|
262
|
-
|
|
263
|
-
|
|
269
|
+
A : torch.Tensor
|
|
270
|
+
First tensor.
|
|
271
|
+
B : torch.Tensor
|
|
272
|
+
Second tensor.
|
|
273
|
+
groups : int
|
|
264
274
|
The number of groups.
|
|
265
275
|
|
|
266
276
|
Returns
|
|
267
277
|
-------
|
|
268
278
|
torch.Tensor
|
|
279
|
+
Interleaved tensor.
|
|
269
280
|
|
|
270
281
|
Raises
|
|
271
282
|
------
|
|
@@ -322,8 +333,14 @@ class UNet(nn.Module):
|
|
|
322
333
|
Dropout probability, by default 0.0.
|
|
323
334
|
pool_kernel : int, optional
|
|
324
335
|
Kernel size of the pooling layers, by default 2.
|
|
325
|
-
|
|
336
|
+
final_activation : Optional[Callable], optional
|
|
326
337
|
Activation function to use for the last layer, by default None.
|
|
338
|
+
n2v2 : bool, optional
|
|
339
|
+
Whether to use N2V2 architecture, by default False.
|
|
340
|
+
independent_channels : bool
|
|
341
|
+
Whether to train the channels independently, by default True.
|
|
342
|
+
**kwargs : Any
|
|
343
|
+
Additional keyword arguments, unused.
|
|
327
344
|
"""
|
|
328
345
|
|
|
329
346
|
def __init__(
|
|
@@ -362,11 +379,15 @@ class UNet(nn.Module):
|
|
|
362
379
|
Dropout probability, by default 0.0.
|
|
363
380
|
pool_kernel : int, optional
|
|
364
381
|
Kernel size of the pooling layers, by default 2.
|
|
365
|
-
|
|
382
|
+
final_activation : Optional[Callable], optional
|
|
366
383
|
Activation function to use for the last layer, by default None.
|
|
384
|
+
n2v2 : bool, optional
|
|
385
|
+
Whether to use N2V2 architecture, by default False.
|
|
367
386
|
independent_channels : bool
|
|
368
387
|
Whether to train parallel independent networks for each channel, by
|
|
369
388
|
default True.
|
|
389
|
+
**kwargs : Any
|
|
390
|
+
Additional keyword arguments, unused.
|
|
370
391
|
"""
|
|
371
392
|
super().__init__()
|
|
372
393
|
|
|
@@ -1,8 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Prediction convenience functions.
|
|
3
|
-
|
|
4
|
-
These functions are used during prediction.
|
|
5
|
-
"""
|
|
1
|
+
"""Prediction utility functions."""
|
|
6
2
|
|
|
7
3
|
from typing import List
|
|
8
4
|
|
|
@@ -21,7 +17,7 @@ def stitch_prediction(
|
|
|
21
17
|
----------
|
|
22
18
|
tiles : List[torch.Tensor]
|
|
23
19
|
Cropped tiles and their respective stitching coordinates.
|
|
24
|
-
|
|
20
|
+
stitching_data : List
|
|
25
21
|
List of information and coordinates obtained from
|
|
26
22
|
`dataset.tiled_patching.extract_tiles`.
|
|
27
23
|
|
careamics/transforms/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"get_all_transforms",
|
|
5
5
|
"N2VManipulate",
|
|
6
|
-
"
|
|
6
|
+
"XYFlip",
|
|
7
7
|
"XYRandomRotate90",
|
|
8
8
|
"ImageRestorationTTA",
|
|
9
9
|
"Denormalize",
|
|
@@ -14,7 +14,7 @@ __all__ = [
|
|
|
14
14
|
|
|
15
15
|
from .compose import Compose, get_all_transforms
|
|
16
16
|
from .n2v_manipulate import N2VManipulate
|
|
17
|
-
from .nd_flip import NDFlip
|
|
18
17
|
from .normalize import Denormalize, Normalize
|
|
19
18
|
from .tta import ImageRestorationTTA
|
|
19
|
+
from .xy_flip import XYFlip
|
|
20
20
|
from .xy_random_rotate90 import XYRandomRotate90
|
careamics/transforms/compose.py
CHANGED
|
@@ -1,26 +1,26 @@
|
|
|
1
1
|
"""A class chaining transforms together."""
|
|
2
2
|
|
|
3
|
-
from typing import Callable, List, Optional, Tuple
|
|
3
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
7
|
from careamics.config.data_model import TRANSFORMS_UNION
|
|
8
8
|
|
|
9
9
|
from .n2v_manipulate import N2VManipulate
|
|
10
|
-
from .nd_flip import NDFlip
|
|
11
10
|
from .normalize import Normalize
|
|
12
11
|
from .transform import Transform
|
|
12
|
+
from .xy_flip import XYFlip
|
|
13
13
|
from .xy_random_rotate90 import XYRandomRotate90
|
|
14
14
|
|
|
15
15
|
ALL_TRANSFORMS = {
|
|
16
16
|
"Normalize": Normalize,
|
|
17
17
|
"N2VManipulate": N2VManipulate,
|
|
18
|
-
"
|
|
18
|
+
"XYFlip": XYFlip,
|
|
19
19
|
"XYRandomRotate90": XYRandomRotate90,
|
|
20
20
|
}
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
def get_all_transforms() ->
|
|
23
|
+
def get_all_transforms() -> Dict[str, type]:
|
|
24
24
|
"""Return all the transforms accepted by CAREamics.
|
|
25
25
|
|
|
26
26
|
Returns
|
|
@@ -33,7 +33,19 @@ def get_all_transforms() -> dict:
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class Compose:
|
|
36
|
-
"""A class chaining transforms together.
|
|
36
|
+
"""A class chaining transforms together.
|
|
37
|
+
|
|
38
|
+
Parameters
|
|
39
|
+
----------
|
|
40
|
+
transform_list : List[TRANSFORMS_UNION]
|
|
41
|
+
A list of dictionaries where each dictionary contains the name of a
|
|
42
|
+
transform and its parameters.
|
|
43
|
+
|
|
44
|
+
Attributes
|
|
45
|
+
----------
|
|
46
|
+
_callable_transforms : Callable
|
|
47
|
+
A callable that applies the transforms to the input data.
|
|
48
|
+
"""
|
|
37
49
|
|
|
38
50
|
def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None:
|
|
39
51
|
"""Instantiate a Compose object.
|
|
@@ -68,7 +80,21 @@ class Compose:
|
|
|
68
80
|
|
|
69
81
|
def _chain(
|
|
70
82
|
patch: np.ndarray, target: Optional[np.ndarray]
|
|
71
|
-
) -> Tuple[np.ndarray,
|
|
83
|
+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
84
|
+
"""Chain transforms on the input data.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
patch : np.ndarray
|
|
89
|
+
Input data.
|
|
90
|
+
target : Optional[np.ndarray]
|
|
91
|
+
Target data, by default None.
|
|
92
|
+
|
|
93
|
+
Returns
|
|
94
|
+
-------
|
|
95
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
96
|
+
The output of the transformations.
|
|
97
|
+
"""
|
|
72
98
|
params = (patch, target)
|
|
73
99
|
|
|
74
100
|
for t in transforms:
|
|
@@ -88,7 +114,7 @@ class Compose:
|
|
|
88
114
|
patch : np.ndarray
|
|
89
115
|
The input data.
|
|
90
116
|
target : Optional[np.ndarray], optional
|
|
91
|
-
Target data, by default None
|
|
117
|
+
Target data, by default None.
|
|
92
118
|
|
|
93
119
|
Returns
|
|
94
120
|
-------
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""N2V manipulation transform."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Literal, Optional, Tuple
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -17,10 +19,35 @@ class N2VManipulate(Transform):
|
|
|
17
19
|
|
|
18
20
|
Parameters
|
|
19
21
|
----------
|
|
20
|
-
|
|
21
|
-
|
|
22
|
+
roi_size : int, optional
|
|
23
|
+
Size of the replacement area, by default 11.
|
|
24
|
+
masked_pixel_percentage : float, optional
|
|
25
|
+
Percentage of pixels to mask, by default 0.2.
|
|
26
|
+
strategy : Literal[ "uniform", "median" ], optional
|
|
27
|
+
Replaccement strategy, uniform or median, by default uniform.
|
|
28
|
+
remove_center : bool, optional
|
|
29
|
+
Whether to remove central pixel from patch, by default True.
|
|
30
|
+
struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
|
|
31
|
+
StructN2V mask axis, by default "none".
|
|
32
|
+
struct_mask_span : int, optional
|
|
33
|
+
StructN2V mask span, by default 5.
|
|
34
|
+
seed : Optional[int], optional
|
|
35
|
+
Random seed, by default None.
|
|
36
|
+
|
|
37
|
+
Attributes
|
|
38
|
+
----------
|
|
39
|
+
masked_pixel_percentage : float
|
|
40
|
+
Percentage of pixels to mask.
|
|
22
41
|
roi_size : int
|
|
23
|
-
Size of the
|
|
42
|
+
Size of the replacement area.
|
|
43
|
+
strategy : Literal[ "uniform", "median" ]
|
|
44
|
+
Replaccement strategy, uniform or median.
|
|
45
|
+
remove_center : bool
|
|
46
|
+
Whether to remove central pixel from patch.
|
|
47
|
+
struct_mask : Optional[StructMaskParameters]
|
|
48
|
+
StructN2V mask parameters.
|
|
49
|
+
rng : Generator
|
|
50
|
+
Random number generator.
|
|
24
51
|
"""
|
|
25
52
|
|
|
26
53
|
def __init__(
|
|
@@ -40,24 +67,24 @@ class N2VManipulate(Transform):
|
|
|
40
67
|
Parameters
|
|
41
68
|
----------
|
|
42
69
|
roi_size : int, optional
|
|
43
|
-
Size of the replacement area, by default 11
|
|
70
|
+
Size of the replacement area, by default 11.
|
|
44
71
|
masked_pixel_percentage : float, optional
|
|
45
|
-
Percentage of pixels to mask, by default 0.2
|
|
72
|
+
Percentage of pixels to mask, by default 0.2.
|
|
46
73
|
strategy : Literal[ "uniform", "median" ], optional
|
|
47
|
-
Replaccement strategy, uniform or median, by default uniform
|
|
74
|
+
Replaccement strategy, uniform or median, by default uniform.
|
|
48
75
|
remove_center : bool, optional
|
|
49
|
-
Whether to remove central pixel from patch, by default True
|
|
76
|
+
Whether to remove central pixel from patch, by default True.
|
|
50
77
|
struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
|
|
51
|
-
StructN2V mask axis, by default "none"
|
|
78
|
+
StructN2V mask axis, by default "none".
|
|
52
79
|
struct_mask_span : int, optional
|
|
53
|
-
StructN2V mask span, by default 5
|
|
80
|
+
StructN2V mask span, by default 5.
|
|
54
81
|
seed : Optional[int], optional
|
|
55
|
-
Random seed, by default None
|
|
82
|
+
Random seed, by default None.
|
|
56
83
|
"""
|
|
57
84
|
self.masked_pixel_percentage = masked_pixel_percentage
|
|
58
85
|
self.roi_size = roi_size
|
|
59
86
|
self.strategy = strategy
|
|
60
|
-
self.remove_center = remove_center
|
|
87
|
+
self.remove_center = remove_center # TODO is this ever used?
|
|
61
88
|
|
|
62
89
|
if struct_mask_axis == SupportedStructAxis.NONE:
|
|
63
90
|
self.struct_mask: Optional[StructMaskParameters] = None
|
|
@@ -77,8 +104,17 @@ class N2VManipulate(Transform):
|
|
|
77
104
|
|
|
78
105
|
Parameters
|
|
79
106
|
----------
|
|
80
|
-
|
|
81
|
-
Image
|
|
107
|
+
patch : np.ndarray
|
|
108
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
109
|
+
*args : Any
|
|
110
|
+
Additional arguments, unused.
|
|
111
|
+
**kwargs : Any
|
|
112
|
+
Additional keyword arguments, unused.
|
|
113
|
+
|
|
114
|
+
Returns
|
|
115
|
+
-------
|
|
116
|
+
Tuple[np.ndarray, np.ndarray, np.ndarray]
|
|
117
|
+
Masked patch, original patch, and mask.
|
|
82
118
|
"""
|
|
83
119
|
masked = np.zeros_like(patch)
|
|
84
120
|
mask = np.zeros_like(patch)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Normalization and denormalization transforms for image patches."""
|
|
2
|
+
|
|
1
3
|
from typing import Optional, Tuple
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -15,6 +17,13 @@ class Normalize(Transform):
|
|
|
15
17
|
Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
16
18
|
division by zero and that it returns a float32 image.
|
|
17
19
|
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
mean : float
|
|
23
|
+
Mean value.
|
|
24
|
+
std : float
|
|
25
|
+
Standard deviation value.
|
|
26
|
+
|
|
18
27
|
Attributes
|
|
19
28
|
----------
|
|
20
29
|
mean : float
|
|
@@ -28,6 +37,15 @@ class Normalize(Transform):
|
|
|
28
37
|
mean: float,
|
|
29
38
|
std: float,
|
|
30
39
|
):
|
|
40
|
+
"""Constructor.
|
|
41
|
+
|
|
42
|
+
Parameters
|
|
43
|
+
----------
|
|
44
|
+
mean : float
|
|
45
|
+
Mean value.
|
|
46
|
+
std : float
|
|
47
|
+
Standard deviation value.
|
|
48
|
+
"""
|
|
31
49
|
self.mean = mean
|
|
32
50
|
self.std = std
|
|
33
51
|
self.eps = 1e-6
|
|
@@ -42,7 +60,7 @@ class Normalize(Transform):
|
|
|
42
60
|
patch : np.ndarray
|
|
43
61
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
44
62
|
target : Optional[np.ndarray], optional
|
|
45
|
-
Target for the patch, by default None
|
|
63
|
+
Target for the patch, by default None.
|
|
46
64
|
|
|
47
65
|
Returns
|
|
48
66
|
-------
|
|
@@ -55,6 +73,19 @@ class Normalize(Transform):
|
|
|
55
73
|
return norm_patch, norm_target
|
|
56
74
|
|
|
57
75
|
def _apply(self, patch: np.ndarray) -> np.ndarray:
|
|
76
|
+
"""
|
|
77
|
+
Apply the transform to the image.
|
|
78
|
+
|
|
79
|
+
Parameters
|
|
80
|
+
----------
|
|
81
|
+
patch : np.ndarray
|
|
82
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
np.ndarray
|
|
87
|
+
Normalizedimage patch.
|
|
88
|
+
"""
|
|
58
89
|
return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
|
|
59
90
|
|
|
60
91
|
|
|
@@ -69,6 +100,13 @@ class Denormalize:
|
|
|
69
100
|
division by zero during the normalization step, which is taken into account during
|
|
70
101
|
denormalization.
|
|
71
102
|
|
|
103
|
+
Parameters
|
|
104
|
+
----------
|
|
105
|
+
mean : float
|
|
106
|
+
Mean value.
|
|
107
|
+
std : float
|
|
108
|
+
Standard deviation value.
|
|
109
|
+
|
|
72
110
|
Attributes
|
|
73
111
|
----------
|
|
74
112
|
mean : float
|
|
@@ -82,6 +120,15 @@ class Denormalize:
|
|
|
82
120
|
mean: float,
|
|
83
121
|
std: float,
|
|
84
122
|
):
|
|
123
|
+
"""Constructor.
|
|
124
|
+
|
|
125
|
+
Parameters
|
|
126
|
+
----------
|
|
127
|
+
mean : float
|
|
128
|
+
Mean.
|
|
129
|
+
std : float
|
|
130
|
+
Standard deviation.
|
|
131
|
+
"""
|
|
85
132
|
self.mean = mean
|
|
86
133
|
self.std = std
|
|
87
134
|
self.eps = 1e-6
|
|
@@ -96,7 +143,7 @@ class Denormalize:
|
|
|
96
143
|
patch : np.ndarray
|
|
97
144
|
Patch, 2D or 3D, shape C(Z)YX.
|
|
98
145
|
target : Optional[np.ndarray], optional
|
|
99
|
-
Target for the patch, by default None
|
|
146
|
+
Target for the patch, by default None.
|
|
100
147
|
|
|
101
148
|
Returns
|
|
102
149
|
-------
|
|
@@ -115,6 +162,11 @@ class Denormalize:
|
|
|
115
162
|
Parameters
|
|
116
163
|
----------
|
|
117
164
|
patch : np.ndarray
|
|
118
|
-
Image
|
|
165
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
166
|
+
|
|
167
|
+
Returns
|
|
168
|
+
-------
|
|
169
|
+
np.ndarray
|
|
170
|
+
Denormalized image patch.
|
|
119
171
|
"""
|
|
120
172
|
return patch * (self.std + self.eps) + self.mean
|
|
@@ -5,7 +5,7 @@ Pixel manipulation is used in N2V and similar algorithm to replace the value of
|
|
|
5
5
|
masked pixels.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import Optional, Tuple
|
|
8
|
+
from typing import Optional, Tuple
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
|
|
@@ -15,7 +15,7 @@ from .struct_mask_parameters import StructMaskParameters
|
|
|
15
15
|
def _apply_struct_mask(
|
|
16
16
|
patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
|
|
17
17
|
) -> np.ndarray:
|
|
18
|
-
"""
|
|
18
|
+
"""Apply structN2V masks to patch.
|
|
19
19
|
|
|
20
20
|
Each point in `coords` corresponds to the center of a mask, masks are paremeterized
|
|
21
21
|
by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by
|
|
@@ -98,7 +98,7 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
|
98
98
|
|
|
99
99
|
|
|
100
100
|
def _get_stratified_coords(
|
|
101
|
-
mask_pixel_perc: float, shape:
|
|
101
|
+
mask_pixel_perc: float, shape: Tuple[int, ...]
|
|
102
102
|
) -> np.ndarray:
|
|
103
103
|
"""
|
|
104
104
|
Generate coordinates of the pixels to mask.
|
|
@@ -248,7 +248,7 @@ def uniform_manipulate(
|
|
|
248
248
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
249
249
|
remove_center : bool
|
|
250
250
|
Whether to remove the center pixel from the subpatch, by default False.
|
|
251
|
-
struct_params: Optional[StructMaskParameters]
|
|
251
|
+
struct_params : Optional[StructMaskParameters]
|
|
252
252
|
Parameters for the structN2V mask (axis and span).
|
|
253
253
|
|
|
254
254
|
Returns
|
|
@@ -322,7 +322,7 @@ def median_manipulate(
|
|
|
322
322
|
Approximate percentage of pixels to be masked.
|
|
323
323
|
subpatch_size : int
|
|
324
324
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
325
|
-
struct_params: Optional[StructMaskParameters]
|
|
325
|
+
struct_params : Optional[StructMaskParameters]
|
|
326
326
|
Parameters for the structN2V mask (axis and span).
|
|
327
327
|
|
|
328
328
|
Returns
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Class representing the parameters of structN2V masks."""
|
|
2
|
+
|
|
1
3
|
from dataclasses import dataclass
|
|
2
4
|
from typing import Literal
|
|
3
5
|
|
|
@@ -6,7 +8,7 @@ from typing import Literal
|
|
|
6
8
|
class StructMaskParameters:
|
|
7
9
|
"""Parameters of structN2V masks.
|
|
8
10
|
|
|
9
|
-
|
|
11
|
+
Attributes
|
|
10
12
|
----------
|
|
11
13
|
axis : Literal[0, 1]
|
|
12
14
|
Axis along which to apply the mask, horizontal (0) or vertical (1).
|
|
@@ -1,33 +1,24 @@
|
|
|
1
1
|
"""A general parent class for transforms."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
3
|
+
from typing import Any
|
|
6
4
|
|
|
7
5
|
|
|
8
6
|
class Transform:
|
|
9
7
|
"""A general parent class for transforms."""
|
|
10
8
|
|
|
11
|
-
def __call__(
|
|
12
|
-
|
|
13
|
-
) -> Tuple[np.ndarray, ...]:
|
|
14
|
-
"""Apply the transform to the input data.
|
|
9
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
10
|
+
"""Apply the transform.
|
|
15
11
|
|
|
16
12
|
Parameters
|
|
17
13
|
----------
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
14
|
+
*args : Any
|
|
15
|
+
Arguments.
|
|
16
|
+
**kwargs : Any
|
|
17
|
+
Keyword arguments.
|
|
22
18
|
|
|
23
19
|
Returns
|
|
24
20
|
-------
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
Raises
|
|
29
|
-
------
|
|
30
|
-
NotImplementedError
|
|
31
|
-
This method should be implemented in the child class.
|
|
21
|
+
Any
|
|
22
|
+
Transformed data.
|
|
32
23
|
"""
|
|
33
|
-
|
|
24
|
+
pass
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""XY flip transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from careamics.transforms.transform import Transform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYFlip(Transform):
|
|
11
|
+
"""Flip image along X and Y axis, one at a time.
|
|
12
|
+
|
|
13
|
+
This transform randomly flips one of the last two axes.
|
|
14
|
+
|
|
15
|
+
This transform expects C(Z)YX dimensions.
|
|
16
|
+
|
|
17
|
+
Attributes
|
|
18
|
+
----------
|
|
19
|
+
axis_indices : List[int]
|
|
20
|
+
Indices of the axes that can be flipped.
|
|
21
|
+
rng : np.random.Generator
|
|
22
|
+
Random number generator.
|
|
23
|
+
p : float
|
|
24
|
+
Probability of applying the transform.
|
|
25
|
+
seed : Optional[int]
|
|
26
|
+
Random seed.
|
|
27
|
+
|
|
28
|
+
Parameters
|
|
29
|
+
----------
|
|
30
|
+
flip_x : bool, optional
|
|
31
|
+
Whether to flip along the X axis, by default True.
|
|
32
|
+
flip_y : bool, optional
|
|
33
|
+
Whether to flip along the Y axis, by default True.
|
|
34
|
+
p : float, optional
|
|
35
|
+
Probability of applying the transform, by default 0.5.
|
|
36
|
+
seed : Optional[int], optional
|
|
37
|
+
Random seed, by default None.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(
|
|
41
|
+
self,
|
|
42
|
+
flip_x: bool = True,
|
|
43
|
+
flip_y: bool = True,
|
|
44
|
+
p: float = 0.5,
|
|
45
|
+
seed: Optional[int] = None,
|
|
46
|
+
) -> None:
|
|
47
|
+
"""Constructor.
|
|
48
|
+
|
|
49
|
+
Parameters
|
|
50
|
+
----------
|
|
51
|
+
flip_x : bool, optional
|
|
52
|
+
Whether to flip along the X axis, by default True.
|
|
53
|
+
flip_y : bool, optional
|
|
54
|
+
Whether to flip along the Y axis, by default True.
|
|
55
|
+
p : float
|
|
56
|
+
Probability of applying the transform, by default 0.5.
|
|
57
|
+
seed : Optional[int], optional
|
|
58
|
+
Random seed, by default None.
|
|
59
|
+
"""
|
|
60
|
+
if p < 0 or p > 1:
|
|
61
|
+
raise ValueError("Probability must be in [0, 1].")
|
|
62
|
+
|
|
63
|
+
if not flip_x and not flip_y:
|
|
64
|
+
raise ValueError("At least one axis must be flippable.")
|
|
65
|
+
|
|
66
|
+
# probability to apply the transform
|
|
67
|
+
self.p = p
|
|
68
|
+
|
|
69
|
+
# "flippable" axes
|
|
70
|
+
self.axis_indices = []
|
|
71
|
+
|
|
72
|
+
if flip_y:
|
|
73
|
+
self.axis_indices.append(-2)
|
|
74
|
+
if flip_x:
|
|
75
|
+
self.axis_indices.append(-1)
|
|
76
|
+
|
|
77
|
+
# numpy random generator
|
|
78
|
+
self.rng = np.random.default_rng(seed=seed)
|
|
79
|
+
|
|
80
|
+
def __call__(
|
|
81
|
+
self, patch: np.ndarray, target: Optional[np.ndarray] = None
|
|
82
|
+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
|
83
|
+
"""Apply the transform to the source patch and the target (optional).
|
|
84
|
+
|
|
85
|
+
Parameters
|
|
86
|
+
----------
|
|
87
|
+
patch : np.ndarray
|
|
88
|
+
Patch, 2D or 3D, shape C(Z)YX.
|
|
89
|
+
target : Optional[np.ndarray], optional
|
|
90
|
+
Target for the patch, by default None.
|
|
91
|
+
|
|
92
|
+
Returns
|
|
93
|
+
-------
|
|
94
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
95
|
+
Transformed patch and target.
|
|
96
|
+
"""
|
|
97
|
+
if self.rng.random() > self.p:
|
|
98
|
+
return patch, target
|
|
99
|
+
|
|
100
|
+
# choose an axis to flip
|
|
101
|
+
axis = self.rng.choice(self.axis_indices)
|
|
102
|
+
|
|
103
|
+
patch_transformed = self._apply(patch, axis)
|
|
104
|
+
target_transformed = self._apply(target, axis) if target is not None else None
|
|
105
|
+
|
|
106
|
+
return patch_transformed, target_transformed
|
|
107
|
+
|
|
108
|
+
def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
|
|
109
|
+
"""Apply the transform to the image.
|
|
110
|
+
|
|
111
|
+
Parameters
|
|
112
|
+
----------
|
|
113
|
+
patch : np.ndarray
|
|
114
|
+
Image patch, 2D or 3D, shape C(Z)YX.
|
|
115
|
+
axis : int
|
|
116
|
+
Axis to flip.
|
|
117
|
+
|
|
118
|
+
Returns
|
|
119
|
+
-------
|
|
120
|
+
np.ndarray
|
|
121
|
+
Flipped image patch.
|
|
122
|
+
"""
|
|
123
|
+
return np.ascontiguousarray(np.flip(patch, axis=axis))
|