careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic; see the package's registry page for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py
CHANGED
|
@@ -4,7 +4,7 @@ UNet model.
|
|
|
4
4
|
A UNet encoder, decoder and complete model.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Union
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
10
|
import torch.nn as nn
|
|
@@ -104,7 +104,7 @@ class UnetEncoder(nn.Module):
|
|
|
104
104
|
encoder_blocks.append(self.pooling)
|
|
105
105
|
self.encoder_blocks = nn.ModuleList(encoder_blocks)
|
|
106
106
|
|
|
107
|
-
def forward(self, x: torch.Tensor) ->
|
|
107
|
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
|
108
108
|
"""
|
|
109
109
|
Forward pass.
|
|
110
110
|
|
|
@@ -115,7 +115,7 @@ class UnetEncoder(nn.Module):
|
|
|
115
115
|
|
|
116
116
|
Returns
|
|
117
117
|
-------
|
|
118
|
-
|
|
118
|
+
list[torch.Tensor]
|
|
119
119
|
Output of each encoder block (skip connections) and final output of the
|
|
120
120
|
encoder.
|
|
121
121
|
"""
|
|
@@ -202,7 +202,7 @@ class UnetDecoder(nn.Module):
|
|
|
202
202
|
groups=self.groups,
|
|
203
203
|
)
|
|
204
204
|
|
|
205
|
-
decoder_blocks:
|
|
205
|
+
decoder_blocks: list[nn.Module] = []
|
|
206
206
|
for n in range(depth):
|
|
207
207
|
decoder_blocks.append(upsampling)
|
|
208
208
|
in_channels = (num_channels_init * 2 ** (depth - n)) * groups
|
|
@@ -230,7 +230,7 @@ class UnetDecoder(nn.Module):
|
|
|
230
230
|
|
|
231
231
|
Parameters
|
|
232
232
|
----------
|
|
233
|
-
*features :
|
|
233
|
+
*features : list[torch.Tensor]
|
|
234
234
|
List containing the output of each encoder block(skip connections) and final
|
|
235
235
|
output of the encoder.
|
|
236
236
|
|
|
@@ -240,7 +240,7 @@ class UnetDecoder(nn.Module):
|
|
|
240
240
|
Output of the decoder.
|
|
241
241
|
"""
|
|
242
242
|
x: torch.Tensor = features[0]
|
|
243
|
-
skip_connections:
|
|
243
|
+
skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
|
|
244
244
|
|
|
245
245
|
x = self.bottleneck(x)
|
|
246
246
|
|
|
@@ -289,10 +289,10 @@ class UnetDecoder(nn.Module):
|
|
|
289
289
|
m = A.shape[1] // groups
|
|
290
290
|
n = B.shape[1] // groups
|
|
291
291
|
|
|
292
|
-
A_groups:
|
|
292
|
+
A_groups: list[torch.Tensor] = [
|
|
293
293
|
A[:, i * m : (i + 1) * m] for i in range(groups)
|
|
294
294
|
]
|
|
295
|
-
B_groups:
|
|
295
|
+
B_groups: list[torch.Tensor] = [
|
|
296
296
|
B[:, i * n : (i + 1) * n] for i in range(groups)
|
|
297
297
|
]
|
|
298
298
|
|
|
careamics/prediction_utils/prediction_outputs.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Module containing functions to convert prediction outputs to desired form."""
|
|
2
2
|
|
|
3
|
-
from typing import Any,
|
|
3
|
+
from typing import Any, Literal, Union, overload
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from numpy.typing import NDArray
|
|
@@ -9,7 +9,7 @@ from ..config.tile_information import TileInformation
|
|
|
9
9
|
from .stitch_prediction import stitch_prediction
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def convert_outputs(predictions: List[Any], tiled: bool) -> list[NDArray]:
|
|
12
|
+
def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
|
|
13
13
|
"""
|
|
14
14
|
Convert the Lightning trainer outputs to the desired form.
|
|
15
15
|
|
|
@@ -25,7 +25,7 @@ def convert_outputs(predictions: List[Any], tiled: bool) -> list[NDArray]:
|
|
|
25
25
|
Returns
|
|
26
26
|
-------
|
|
27
27
|
list of numpy.ndarray or numpy.ndarray
|
|
28
|
-
|
|
28
|
+
list of arrays with the axes SC(Z)YX. If there is only 1 output it will not
|
|
29
29
|
be in a list.
|
|
30
30
|
"""
|
|
31
31
|
if len(predictions) == 0:
|
|
@@ -44,27 +44,27 @@ def convert_outputs(predictions: List[Any], tiled: bool) -> list[NDArray]:
|
|
|
44
44
|
# for mypy
|
|
45
45
|
@overload
|
|
46
46
|
def combine_batches( # numpydoc ignore=GL08
|
|
47
|
-
predictions:
|
|
48
|
-
) ->
|
|
47
|
+
predictions: list[Any], tiled: Literal[True]
|
|
48
|
+
) -> tuple[list[NDArray], list[TileInformation]]: ...
|
|
49
49
|
|
|
50
50
|
|
|
51
51
|
# for mypy
|
|
52
52
|
@overload
|
|
53
53
|
def combine_batches( # numpydoc ignore=GL08
|
|
54
|
-
predictions:
|
|
55
|
-
) ->
|
|
54
|
+
predictions: list[Any], tiled: Literal[False]
|
|
55
|
+
) -> list[NDArray]: ...
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
# for mypy
|
|
59
59
|
@overload
|
|
60
60
|
def combine_batches( # numpydoc ignore=GL08
|
|
61
|
-
predictions:
|
|
62
|
-
) -> Union[
|
|
61
|
+
predictions: list[Any], tiled: Union[bool, Literal[True], Literal[False]]
|
|
62
|
+
) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]: ...
|
|
63
63
|
|
|
64
64
|
|
|
65
65
|
def combine_batches(
|
|
66
|
-
predictions:
|
|
67
|
-
) -> Union[
|
|
66
|
+
predictions: list[Any], tiled: bool
|
|
67
|
+
) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]:
|
|
68
68
|
"""
|
|
69
69
|
If predictions are in batches, they will be combined.
|
|
70
70
|
|
|
@@ -87,8 +87,8 @@ def combine_batches(
|
|
|
87
87
|
|
|
88
88
|
|
|
89
89
|
def _combine_tiled_batches(
|
|
90
|
-
predictions:
|
|
91
|
-
) ->
|
|
90
|
+
predictions: list[tuple[NDArray, list[TileInformation]]]
|
|
91
|
+
) -> tuple[list[NDArray], list[TileInformation]]:
|
|
92
92
|
"""
|
|
93
93
|
Combine batches from tiled output.
|
|
94
94
|
|
|
@@ -109,13 +109,13 @@ def _combine_tiled_batches(
|
|
|
109
109
|
tile_infos = [
|
|
110
110
|
tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
|
|
111
111
|
]
|
|
112
|
-
prediction_tiles:
|
|
112
|
+
prediction_tiles: list[NDArray] = _combine_array_batches(
|
|
113
113
|
[preds for preds, _ in predictions]
|
|
114
114
|
)
|
|
115
115
|
return prediction_tiles, tile_infos
|
|
116
116
|
|
|
117
117
|
|
|
118
|
-
def _combine_array_batches(predictions:
|
|
118
|
+
def _combine_array_batches(predictions: list[NDArray]) -> list[NDArray]:
|
|
119
119
|
"""
|
|
120
120
|
Combine batches of arrays.
|
|
121
121
|
|
|
careamics/prediction_utils/stitch_prediction.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Prediction utility functions."""
|
|
2
2
|
|
|
3
3
|
import builtins
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
from numpy.typing import NDArray
|
|
@@ -11,9 +11,9 @@ from careamics.config.tile_information import TileInformation
|
|
|
11
11
|
|
|
12
12
|
# TODO: why not allow input and output of torch.tensor ?
|
|
13
13
|
def stitch_prediction(
|
|
14
|
-
tiles:
|
|
15
|
-
tile_infos:
|
|
16
|
-
) ->
|
|
14
|
+
tiles: list[np.ndarray],
|
|
15
|
+
tile_infos: list[TileInformation],
|
|
16
|
+
) -> list[np.ndarray]:
|
|
17
17
|
"""
|
|
18
18
|
Stitch tiles back together to form a full image(s).
|
|
19
19
|
|
|
@@ -54,8 +54,8 @@ def stitch_prediction(
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
def stitch_prediction_single(
|
|
57
|
-
tiles:
|
|
58
|
-
tile_infos:
|
|
57
|
+
tiles: list[NDArray],
|
|
58
|
+
tile_infos: list[TileInformation],
|
|
59
59
|
) -> NDArray:
|
|
60
60
|
"""
|
|
61
61
|
Stitch tiles back together to form a full image.
|
careamics/transforms/__init__.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
"""Transforms that are used to augment the data."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
-
"
|
|
4
|
+
"Compose",
|
|
5
|
+
"Denormalize",
|
|
6
|
+
"ImageRestorationTTA",
|
|
5
7
|
"N2VManipulate",
|
|
8
|
+
"Normalize",
|
|
6
9
|
"XYFlip",
|
|
7
10
|
"XYRandomRotate90",
|
|
8
|
-
"
|
|
9
|
-
"Denormalize",
|
|
10
|
-
"Normalize",
|
|
11
|
-
"Compose",
|
|
11
|
+
"get_all_transforms",
|
|
12
12
|
]
|
|
13
13
|
|
|
14
14
|
|
careamics/transforms/compose.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""A class chaining transforms together."""
|
|
2
2
|
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Optional, Union, cast
|
|
4
4
|
|
|
5
5
|
from numpy.typing import NDArray
|
|
6
6
|
|
|
7
|
-
from careamics.config.transformations import
|
|
7
|
+
from careamics.config.transformations import NORM_AND_SPATIAL_UNION
|
|
8
8
|
|
|
9
9
|
from .n2v_manipulate import N2VManipulate
|
|
10
10
|
from .normalize import Normalize
|
|
@@ -20,7 +20,7 @@ ALL_TRANSFORMS = {
|
|
|
20
20
|
}
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
def get_all_transforms() ->
|
|
23
|
+
def get_all_transforms() -> dict[str, type]:
|
|
24
24
|
"""Return all the transforms accepted by CAREamics.
|
|
25
25
|
|
|
26
26
|
Returns
|
|
@@ -37,7 +37,7 @@ class Compose:
|
|
|
37
37
|
|
|
38
38
|
Parameters
|
|
39
39
|
----------
|
|
40
|
-
transform_list :
|
|
40
|
+
transform_list : list[TransformModel]
|
|
41
41
|
A list of dictionaries where each dictionary contains the name of a
|
|
42
42
|
transform and its parameters.
|
|
43
43
|
|
|
@@ -47,12 +47,12 @@ class Compose:
|
|
|
47
47
|
A callable that applies the transforms to the input data.
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
|
-
def __init__(self, transform_list:
|
|
50
|
+
def __init__(self, transform_list: list[NORM_AND_SPATIAL_UNION]) -> None:
|
|
51
51
|
"""Instantiate a Compose object.
|
|
52
52
|
|
|
53
53
|
Parameters
|
|
54
54
|
----------
|
|
55
|
-
transform_list :
|
|
55
|
+
transform_list : list[NORM_AND_SPATIAL_UNION]
|
|
56
56
|
A list of dictionaries where each dictionary contains the name of a
|
|
57
57
|
transform and its parameters.
|
|
58
58
|
"""
|
|
@@ -67,7 +67,7 @@ class Compose:
|
|
|
67
67
|
|
|
68
68
|
def _chain_transforms(
|
|
69
69
|
self, patch: NDArray, target: Optional[NDArray]
|
|
70
|
-
) ->
|
|
70
|
+
) -> tuple[Optional[NDArray], ...]:
|
|
71
71
|
"""Chain transforms on the input data.
|
|
72
72
|
|
|
73
73
|
Parameters
|
|
@@ -79,7 +79,7 @@ class Compose:
|
|
|
79
79
|
|
|
80
80
|
Returns
|
|
81
81
|
-------
|
|
82
|
-
|
|
82
|
+
tuple[np.ndarray, Optional[np.ndarray]]
|
|
83
83
|
The output of the transformations.
|
|
84
84
|
"""
|
|
85
85
|
params: Union[
|
|
@@ -103,7 +103,7 @@ class Compose:
|
|
|
103
103
|
patch: NDArray,
|
|
104
104
|
target: Optional[NDArray],
|
|
105
105
|
**additional_arrays: NDArray,
|
|
106
|
-
) ->
|
|
106
|
+
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
107
107
|
"""Chain transforms on the input data, with additional arrays.
|
|
108
108
|
|
|
109
109
|
Parameters
|
|
@@ -118,7 +118,7 @@ class Compose:
|
|
|
118
118
|
|
|
119
119
|
Returns
|
|
120
120
|
-------
|
|
121
|
-
|
|
121
|
+
tuple[np.ndarray, Optional[np.ndarray]]
|
|
122
122
|
The output of the transformations.
|
|
123
123
|
"""
|
|
124
124
|
params = {"patch": patch, "target": target, **additional_arrays}
|
|
@@ -131,7 +131,7 @@ class Compose:
|
|
|
131
131
|
|
|
132
132
|
def __call__(
|
|
133
133
|
self, patch: NDArray, target: Optional[NDArray] = None
|
|
134
|
-
) ->
|
|
134
|
+
) -> tuple[NDArray, ...]:
|
|
135
135
|
"""Apply the transforms to the input data.
|
|
136
136
|
|
|
137
137
|
Parameters
|
|
@@ -143,11 +143,11 @@ class Compose:
|
|
|
143
143
|
|
|
144
144
|
Returns
|
|
145
145
|
-------
|
|
146
|
-
|
|
146
|
+
tuple[np.ndarray, ...]
|
|
147
147
|
The output of the transformations.
|
|
148
148
|
"""
|
|
149
149
|
# TODO: solve casting Compose.__call__ ouput
|
|
150
|
-
return cast(
|
|
150
|
+
return cast(tuple[NDArray, ...], self._chain_transforms(patch, target))
|
|
151
151
|
|
|
152
152
|
def transform_with_additional_arrays(
|
|
153
153
|
self,
|
|
careamics/transforms/n2v_manipulate.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""N2V manipulation transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Literal, Optional
|
|
3
|
+
from typing import Any, Literal, Optional
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from numpy.typing import NDArray
|
|
@@ -100,7 +100,7 @@ class N2VManipulate(Transform):
|
|
|
100
100
|
|
|
101
101
|
def __call__(
|
|
102
102
|
self, patch: NDArray, *args: Any, **kwargs: Any
|
|
103
|
-
) ->
|
|
103
|
+
) -> tuple[NDArray, NDArray, NDArray]:
|
|
104
104
|
"""Apply the transform to the image.
|
|
105
105
|
|
|
106
106
|
Parameters
|
|
@@ -114,7 +114,7 @@ class N2VManipulate(Transform):
|
|
|
114
114
|
|
|
115
115
|
Returns
|
|
116
116
|
-------
|
|
117
|
-
|
|
117
|
+
tuple[np.ndarray, np.ndarray, np.ndarray]
|
|
118
118
|
Masked patch, original patch, and mask.
|
|
119
119
|
"""
|
|
120
120
|
masked = np.zeros_like(patch)
|
|
careamics/transforms/pixel_manipulation.py
CHANGED
|
@@ -5,7 +5,7 @@ Pixel manipulation is used in N2V and similar algorithm to replace the value of
|
|
|
5
5
|
masked pixels.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import Optional
|
|
8
|
+
from typing import Optional
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
|
|
@@ -107,7 +107,7 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
|
107
107
|
|
|
108
108
|
def _get_stratified_coords(
|
|
109
109
|
mask_pixel_perc: float,
|
|
110
|
-
shape:
|
|
110
|
+
shape: tuple[int, ...],
|
|
111
111
|
rng: Optional[np.random.Generator] = None,
|
|
112
112
|
) -> np.ndarray:
|
|
113
113
|
"""
|
|
@@ -121,7 +121,7 @@ def _get_stratified_coords(
|
|
|
121
121
|
mask_pixel_perc : float
|
|
122
122
|
Actual (quasi) percentage of masked pixels across the whole image. Used in
|
|
123
123
|
calculating the distance between masked pixels across each axis.
|
|
124
|
-
shape :
|
|
124
|
+
shape : tuple[int, ...]
|
|
125
125
|
Shape of the input patch.
|
|
126
126
|
rng : np.random.Generator or None
|
|
127
127
|
Random number generator.
|
|
@@ -242,7 +242,7 @@ def uniform_manipulate(
|
|
|
242
242
|
remove_center: bool = True,
|
|
243
243
|
struct_params: Optional[StructMaskParameters] = None,
|
|
244
244
|
rng: Optional[np.random.Generator] = None,
|
|
245
|
-
) ->
|
|
245
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
246
246
|
"""
|
|
247
247
|
Manipulate pixels by replacing them with a neighbor values.
|
|
248
248
|
|
|
@@ -269,8 +269,8 @@ def uniform_manipulate(
|
|
|
269
269
|
|
|
270
270
|
Returns
|
|
271
271
|
-------
|
|
272
|
-
|
|
273
|
-
|
|
272
|
+
tuple[np.ndarray]
|
|
273
|
+
tuple containing the manipulated patch and the corresponding mask.
|
|
274
274
|
"""
|
|
275
275
|
if rng is None:
|
|
276
276
|
rng = np.random.default_rng()
|
|
@@ -322,7 +322,7 @@ def median_manipulate(
|
|
|
322
322
|
subpatch_size: int = 11,
|
|
323
323
|
struct_params: Optional[StructMaskParameters] = None,
|
|
324
324
|
rng: Optional[np.random.Generator] = None,
|
|
325
|
-
) ->
|
|
325
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
326
326
|
"""
|
|
327
327
|
Manipulate pixels by replacing them with the median of their surrounding subpatch.
|
|
328
328
|
|
|
@@ -348,8 +348,8 @@ def median_manipulate(
|
|
|
348
348
|
|
|
349
349
|
Returns
|
|
350
350
|
-------
|
|
351
|
-
|
|
352
|
-
|
|
351
|
+
tuple[np.ndarray]
|
|
352
|
+
tuple containing the manipulated patch, the original patch and the mask.
|
|
353
353
|
"""
|
|
354
354
|
if rng is None:
|
|
355
355
|
rng = np.random.default_rng()
|
|
careamics/transforms/xy_random_rotate90.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Patch transform applying XY random 90 degrees rotations."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from numpy.typing import NDArray
|
|
@@ -69,7 +69,7 @@ class XYRandomRotate90(Transform):
|
|
|
69
69
|
|
|
70
70
|
Returns
|
|
71
71
|
-------
|
|
72
|
-
|
|
72
|
+
tuple[np.ndarray, Optional[np.ndarray]]
|
|
73
73
|
Transformed patch and target.
|
|
74
74
|
"""
|
|
75
75
|
if self.rng.random() > self.p:
|
|
@@ -90,7 +90,7 @@ class XYRandomRotate90(Transform):
|
|
|
90
90
|
|
|
91
91
|
return patch_transformed, target_transformed, additional_transformed
|
|
92
92
|
|
|
93
|
-
def _apply(self, patch: NDArray, n_rot: int, axes:
|
|
93
|
+
def _apply(self, patch: NDArray, n_rot: int, axes: tuple[int, int]) -> NDArray:
|
|
94
94
|
"""Apply the transform to the image.
|
|
95
95
|
|
|
96
96
|
Parameters
|
|
@@ -99,7 +99,7 @@ class XYRandomRotate90(Transform):
|
|
|
99
99
|
Image or image patch, 2D or 3D, shape C(Z)YX.
|
|
100
100
|
n_rot : int
|
|
101
101
|
Number of 90 degree rotations.
|
|
102
|
-
axes :
|
|
102
|
+
axes : tuple[int, int]
|
|
103
103
|
Axes along which to rotate the patch.
|
|
104
104
|
|
|
105
105
|
Returns
|
careamics/utils/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Utils module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
-
"cwd",
|
|
5
|
-
"get_ram_size",
|
|
6
|
-
"check_path_exists",
|
|
7
4
|
"BaseEnum",
|
|
8
|
-
"get_logger",
|
|
9
|
-
"get_careamics_home",
|
|
10
5
|
"autocorrelation",
|
|
6
|
+
"check_path_exists",
|
|
7
|
+
"cwd",
|
|
8
|
+
"get_careamics_home",
|
|
9
|
+
"get_logger",
|
|
10
|
+
"get_ram_size",
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
|
careamics/utils/context.py
CHANGED
|
@@ -5,9 +5,10 @@ A convenience function to change the working directory in order to save data.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import os
|
|
8
|
+
from collections.abc import Iterator
|
|
8
9
|
from contextlib import contextmanager
|
|
9
10
|
from pathlib import Path
|
|
10
|
-
from typing import
|
|
11
|
+
from typing import Union
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
def get_careamics_home() -> Path:
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""PyTorch lightning utilities."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def read_csv_logger(experiment_name: str, log_folder: Union[str, Path]) -> dict:
|
|
8
|
+
"""Return the loss curves from the csv logs.
|
|
9
|
+
|
|
10
|
+
Parameters
|
|
11
|
+
----------
|
|
12
|
+
experiment_name : str
|
|
13
|
+
Name of the experiment.
|
|
14
|
+
log_folder : Path or str
|
|
15
|
+
Path to the folder containing the csv logs.
|
|
16
|
+
|
|
17
|
+
Returns
|
|
18
|
+
-------
|
|
19
|
+
dict
|
|
20
|
+
Dictionary containing the loss curves, with keys "train_epoch", "val_epoch",
|
|
21
|
+
"train_loss" and "val_loss".
|
|
22
|
+
"""
|
|
23
|
+
path = Path(log_folder) / experiment_name
|
|
24
|
+
|
|
25
|
+
# find the most recent of version_* folders
|
|
26
|
+
versions = [int(v.name.split("_")[-1]) for v in path.iterdir() if v.is_dir()]
|
|
27
|
+
version = max(versions)
|
|
28
|
+
|
|
29
|
+
path_log = path / f"version_{version}" / "metrics.csv"
|
|
30
|
+
|
|
31
|
+
epochs = []
|
|
32
|
+
train_losses_tmp = []
|
|
33
|
+
val_losses_tmp = []
|
|
34
|
+
with open(path_log) as f:
|
|
35
|
+
lines = f.readlines()
|
|
36
|
+
|
|
37
|
+
for single_line in lines[1:]:
|
|
38
|
+
epoch, _, train_loss, _, val_loss = single_line.strip().split(",")
|
|
39
|
+
|
|
40
|
+
epochs.append(epoch)
|
|
41
|
+
train_losses_tmp.append(train_loss)
|
|
42
|
+
val_losses_tmp.append(val_loss)
|
|
43
|
+
|
|
44
|
+
# train and val are not logged on the same row and can have different lengths
|
|
45
|
+
train_epoch = [
|
|
46
|
+
int(epochs[i]) for i in range(len(epochs)) if train_losses_tmp[i] != ""
|
|
47
|
+
]
|
|
48
|
+
val_epoch = [int(epochs[i]) for i in range(len(epochs)) if val_losses_tmp[i] != ""]
|
|
49
|
+
train_losses = [float(loss) for loss in train_losses_tmp if loss != ""]
|
|
50
|
+
val_losses = [float(loss) for loss in val_losses_tmp if loss != ""]
|
|
51
|
+
|
|
52
|
+
return {
|
|
53
|
+
"train_epoch": train_epoch,
|
|
54
|
+
"val_epoch": val_epoch,
|
|
55
|
+
"train_loss": train_losses,
|
|
56
|
+
"val_loss": val_losses,
|
|
57
|
+
}
|
careamics/utils/logging.py
CHANGED
|
@@ -7,8 +7,9 @@ The methods are responsible for the in-console logger.
|
|
|
7
7
|
import logging
|
|
8
8
|
import sys
|
|
9
9
|
import time
|
|
10
|
+
from collections.abc import Generator
|
|
10
11
|
from pathlib import Path
|
|
11
|
-
from typing import Any,
|
|
12
|
+
from typing import Any, Optional, Union
|
|
12
13
|
|
|
13
14
|
LOGGERS: dict = {}
|
|
14
15
|
|
|
@@ -84,7 +85,7 @@ class ProgressBar:
|
|
|
84
85
|
Zero-indexed current epoch, by default None.
|
|
85
86
|
num_epochs : Optional[int], optional
|
|
86
87
|
Total number of epochs, by default None.
|
|
87
|
-
stateful_metrics : Optional[
|
|
88
|
+
stateful_metrics : Optional[list], optional
|
|
88
89
|
Iterable of string names of metrics that should *not* be averaged over time.
|
|
89
90
|
Metrics in this list will be displayed as-is. All others will be averaged by
|
|
90
91
|
the progress bar before display, by default None.
|
|
@@ -99,7 +100,7 @@ class ProgressBar:
|
|
|
99
100
|
max_value: Optional[int] = None,
|
|
100
101
|
epoch: Optional[int] = None,
|
|
101
102
|
num_epochs: Optional[int] = None,
|
|
102
|
-
stateful_metrics: Optional[
|
|
103
|
+
stateful_metrics: Optional[list] = None,
|
|
103
104
|
always_stateful: bool = False,
|
|
104
105
|
mode: str = "train",
|
|
105
106
|
) -> None:
|
|
@@ -114,7 +115,7 @@ class ProgressBar:
|
|
|
114
115
|
Zero-indexed current epoch, by default None.
|
|
115
116
|
num_epochs : Optional[int], optional
|
|
116
117
|
Total number of epochs, by default None.
|
|
117
|
-
stateful_metrics : Optional[
|
|
118
|
+
stateful_metrics : Optional[list], optional
|
|
118
119
|
Iterable of string names of metrics that should *not* be averaged over time.
|
|
119
120
|
Metrics in this list will be displayed as-is. All others will be averaged by
|
|
120
121
|
the progress bar before display, by default None.
|
|
@@ -145,8 +146,8 @@ class ProgressBar:
|
|
|
145
146
|
self._seen_so_far = 0
|
|
146
147
|
# We use a dict + list to avoid garbage collection
|
|
147
148
|
# issues found in OrderedDict
|
|
148
|
-
self._values:
|
|
149
|
-
self._values_order:
|
|
149
|
+
self._values: dict[Any, Any] = {}
|
|
150
|
+
self._values_order: list[Any] = []
|
|
150
151
|
self._start = time.time()
|
|
151
152
|
self._last_update = 0.0
|
|
152
153
|
self.spin = self.spinning_cursor() if self.max_value is None else None
|
|
@@ -158,7 +159,7 @@ class ProgressBar:
|
|
|
158
159
|
self.message = "Denoising"
|
|
159
160
|
|
|
160
161
|
def update(
|
|
161
|
-
self, current_step: int, batch_size: int = 1, values: Optional[
|
|
162
|
+
self, current_step: int, batch_size: int = 1, values: Optional[list] = None
|
|
162
163
|
) -> None:
|
|
163
164
|
"""
|
|
164
165
|
Update the progress bar.
|
|
@@ -169,7 +170,7 @@ class ProgressBar:
|
|
|
169
170
|
Index of the current step.
|
|
170
171
|
batch_size : int, optional
|
|
171
172
|
Batch size, by default 1.
|
|
172
|
-
values : Optional[
|
|
173
|
+
values : Optional[list], optional
|
|
173
174
|
Updated metrics values, by default None.
|
|
174
175
|
"""
|
|
175
176
|
values = values or []
|
|
@@ -263,7 +264,7 @@ class ProgressBar:
|
|
|
263
264
|
|
|
264
265
|
self._last_update = now
|
|
265
266
|
|
|
266
|
-
def add(self, n: int, values: Optional[
|
|
267
|
+
def add(self, n: int, values: Optional[list] = None) -> None:
|
|
267
268
|
"""
|
|
268
269
|
Update the progress bar by n steps.
|
|
269
270
|
|
|
@@ -271,7 +272,7 @@ class ProgressBar:
|
|
|
271
272
|
----------
|
|
272
273
|
n : int
|
|
273
274
|
Number of steps to increase the progress bar with.
|
|
274
|
-
values : Optional[
|
|
275
|
+
values : Optional[list], optional
|
|
275
276
|
Updated metrics values, by default None.
|
|
276
277
|
"""
|
|
277
278
|
self.update(self._seen_so_far + n, 1, values=values)
|