careamics 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +49 -49
- careamics/cli/conf.py +6 -6
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/algorithms/vae_algorithm_model.py +4 -4
- careamics/config/callback_model.py +8 -8
- careamics/config/configuration_factories.py +49 -49
- careamics/config/data/data_model.py +7 -13
- careamics/config/data/ng_data_model.py +8 -14
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -10
- careamics/config/likelihood_model.py +2 -2
- careamics/config/nm_model.py +5 -7
- careamics/config/training_model.py +4 -4
- careamics/config/transformations/normalize_model.py +3 -3
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patching_strategies/random_patching.py +3 -4
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +12 -14
- careamics/lightning/predict_data_module.py +8 -8
- careamics/lightning/train_data_module.py +11 -11
- careamics/losses/lvae/losses.py +9 -9
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +4 -4
- careamics/models/layers.py +5 -5
- careamics/models/unet.py +16 -10
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +3 -5
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/METADATA +2 -3
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/RECORD +57 -57
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/WHEEL +0 -0
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/licenses/LICENSE +0 -0
careamics/models/unet.py
CHANGED
|
@@ -205,16 +205,23 @@ class UnetDecoder(nn.Module):
|
|
|
205
205
|
decoder_blocks: list[nn.Module] = []
|
|
206
206
|
for n in range(depth):
|
|
207
207
|
decoder_blocks.append(upsampling)
|
|
208
|
-
|
|
209
|
-
|
|
208
|
+
|
|
209
|
+
in_channels = (num_channels_init * 2 ** (depth - n - 1)) * groups
|
|
210
|
+
# final decoder block has the same number in and out features
|
|
211
|
+
out_channels = in_channels // 2 if n != depth - 1 else in_channels
|
|
212
|
+
if not (n2v2 and (n == depth - 1)):
|
|
213
|
+
in_channels = in_channels * 2 # accounting for skip connection concat
|
|
214
|
+
|
|
210
215
|
decoder_blocks.append(
|
|
211
216
|
Conv_Block(
|
|
212
217
|
conv_dim,
|
|
213
|
-
in_channels=
|
|
214
|
-
in_channels + in_channels // 2 if n > 0 else in_channels
|
|
215
|
-
),
|
|
218
|
+
in_channels=in_channels,
|
|
216
219
|
out_channels=out_channels,
|
|
217
|
-
|
|
220
|
+
# TODO: Tensorflow n2v implementation has intermediate channel
|
|
221
|
+
# multiplication for skip_skipone=True but not skip_skipone=False
|
|
222
|
+
# this needs to be benchmarked.
|
|
223
|
+
# final decoder block doesn't multiply the intermediate features
|
|
224
|
+
intermediate_channel_multiplier=2 if n != depth - 1 else 1,
|
|
218
225
|
dropout_perc=dropout,
|
|
219
226
|
activation="ReLU",
|
|
220
227
|
use_batch_norm=use_batch_norm,
|
|
@@ -241,6 +248,7 @@ class UnetDecoder(nn.Module):
|
|
|
241
248
|
"""
|
|
242
249
|
x: torch.Tensor = features[0]
|
|
243
250
|
skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
|
|
251
|
+
depth = len(skip_connections)
|
|
244
252
|
|
|
245
253
|
x = self.bottleneck(x)
|
|
246
254
|
|
|
@@ -249,10 +257,8 @@ class UnetDecoder(nn.Module):
|
|
|
249
257
|
if isinstance(module, nn.Upsample):
|
|
250
258
|
# divide index by 2 because of upsampling layers
|
|
251
259
|
skip_connection: torch.Tensor = skip_connections[i // 2]
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
x = self._interleave(x, skip_connection, self.groups)
|
|
255
|
-
else:
|
|
260
|
+
# top level skip connection not added for n2v2
|
|
261
|
+
if (not self.n2v2) or (self.n2v2 and (i // 2 < depth - 1)):
|
|
256
262
|
x = self._interleave(x, skip_connection, self.groups)
|
|
257
263
|
return x
|
|
258
264
|
|
careamics/prediction_utils/lvae_prediction.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Module containing pytorch implementations for obtaining predictions from an LVAE."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Optional
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
@@ -18,7 +18,7 @@ def lvae_predict_single_sample(
|
|
|
18
18
|
model: LVAE,
|
|
19
19
|
likelihood_obj: LikelihoodModule,
|
|
20
20
|
input: torch.Tensor,
|
|
21
|
-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
21
|
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
22
22
|
"""
|
|
23
23
|
Generate a single sample prediction from an LVAE model, for a given input.
|
|
24
24
|
|
|
@@ -57,7 +57,7 @@ def lvae_predict_tiled_batch(
|
|
|
57
57
|
model: LVAE,
|
|
58
58
|
likelihood_obj: LikelihoodModule,
|
|
59
59
|
input: tuple[Any],
|
|
60
|
-
) -> tuple[tuple[Any], Optional[tuple[Any]]]:
|
|
60
|
+
) -> tuple[tuple[Any], tuple[Any] | None]:
|
|
61
61
|
# TODO: fix docstring return types, ... too many output options
|
|
62
62
|
"""
|
|
63
63
|
Generate a single sample prediction from an LVAE model, for a given input.
|
|
@@ -98,7 +98,7 @@ def lvae_predict_mmse_tiled_batch(
|
|
|
98
98
|
likelihood_obj: LikelihoodModule,
|
|
99
99
|
input: tuple[Any],
|
|
100
100
|
mmse_count: int,
|
|
101
|
-
) -> tuple[tuple[Any], tuple[Any], Optional[tuple[Any]]]:
|
|
101
|
+
) -> tuple[tuple[Any], tuple[Any], tuple[Any] | None]:
|
|
102
102
|
# TODO: fix docstring return types, ... hard to make readable
|
|
103
103
|
"""
|
|
104
104
|
Generate the MMSE (minimum mean squared error) prediction, for a given input.
|
|
@@ -137,7 +137,7 @@ def lvae_predict_mmse_tiled_batch(
|
|
|
137
137
|
|
|
138
138
|
input_shape = x.shape
|
|
139
139
|
output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
|
|
140
|
-
log_var: Optional[torch.Tensor] = None
|
|
140
|
+
log_var: torch.Tensor | None = None
|
|
141
141
|
# pre-declare empty array to fill with individual sample predictions
|
|
142
142
|
sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
|
|
143
143
|
for mmse_idx in range(mmse_count):
|
careamics/transforms/compose.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""A class chaining transforms together."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional, Union, cast
|
|
3
|
+
from typing import Union, cast
|
|
4
4
|
|
|
5
5
|
from numpy.typing import NDArray
|
|
6
6
|
|
|
@@ -64,8 +64,8 @@ class Compose:
|
|
|
64
64
|
]
|
|
65
65
|
|
|
66
66
|
def _chain_transforms(
|
|
67
|
-
self, patch: NDArray, target: Optional[NDArray]
|
|
68
|
-
) -> tuple[Optional[NDArray], ...]:
|
|
67
|
+
self, patch: NDArray, target: NDArray | None
|
|
68
|
+
) -> tuple[NDArray | None, ...]:
|
|
69
69
|
"""Chain transforms on the input data.
|
|
70
70
|
|
|
71
71
|
Parameters
|
|
@@ -80,7 +80,7 @@ class Compose:
|
|
|
80
80
|
tuple[np.ndarray, Optional[np.ndarray]]
|
|
81
81
|
The output of the transformations.
|
|
82
82
|
"""
|
|
83
|
-
params: Union[tuple[NDArray, Optional[NDArray]],] = (patch, target)
|
|
83
|
+
params: Union[tuple[NDArray, NDArray | None],] = (patch, target)
|
|
84
84
|
|
|
85
85
|
for t in self.transforms:
|
|
86
86
|
*params, _ = t(*params) # ignore additional_arrays dict
|
|
@@ -92,9 +92,9 @@ class Compose:
|
|
|
92
92
|
def _chain_transforms_additional_arrays(
|
|
93
93
|
self,
|
|
94
94
|
patch: NDArray,
|
|
95
|
-
target: Optional[NDArray],
|
|
95
|
+
target: NDArray | None,
|
|
96
96
|
**additional_arrays: NDArray,
|
|
97
|
-
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
97
|
+
) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
|
|
98
98
|
"""Chain transforms on the input data, with additional arrays.
|
|
99
99
|
|
|
100
100
|
Parameters
|
|
@@ -121,7 +121,7 @@ class Compose:
|
|
|
121
121
|
return patch, target, additional_arrays
|
|
122
122
|
|
|
123
123
|
def __call__(
|
|
124
|
-
self, patch: NDArray, target: Optional[NDArray] = None
|
|
124
|
+
self, patch: NDArray, target: NDArray | None = None
|
|
125
125
|
) -> tuple[NDArray, ...]:
|
|
126
126
|
"""Apply the transforms to the input data.
|
|
127
127
|
|
|
@@ -143,9 +143,9 @@ class Compose:
|
|
|
143
143
|
def transform_with_additional_arrays(
|
|
144
144
|
self,
|
|
145
145
|
patch: NDArray,
|
|
146
|
-
target: Optional[NDArray] = None,
|
|
146
|
+
target: NDArray | None = None,
|
|
147
147
|
**additional_arrays: NDArray,
|
|
148
|
-
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
148
|
+
) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
|
|
149
149
|
"""Apply the transforms to the input data, including additional arrays.
|
|
150
150
|
|
|
151
151
|
Parameters
|
careamics/transforms/n2v_manipulate.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""N2V manipulation transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Literal, Optional
|
|
3
|
+
from typing import Any, Literal
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from numpy.typing import NDArray
|
|
@@ -61,7 +61,7 @@ class N2VManipulate(Transform):
|
|
|
61
61
|
remove_center: bool = True,
|
|
62
62
|
struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
63
63
|
struct_mask_span: int = 5,
|
|
64
|
-
seed: Optional[int] = None,
|
|
64
|
+
seed: int | None = None,
|
|
65
65
|
):
|
|
66
66
|
"""Constructor.
|
|
67
67
|
|
|
@@ -88,7 +88,7 @@ class N2VManipulate(Transform):
|
|
|
88
88
|
self.remove_center = remove_center # TODO is this ever used?
|
|
89
89
|
|
|
90
90
|
if struct_mask_axis == SupportedStructAxis.NONE:
|
|
91
|
-
self.struct_mask: Optional[StructMaskParameters] = None
|
|
91
|
+
self.struct_mask: StructMaskParameters | None = None
|
|
92
92
|
else:
|
|
93
93
|
self.struct_mask = StructMaskParameters(
|
|
94
94
|
axis=0 if struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1,
|
careamics/transforms/n2v_manipulate_torch.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""N2V manipulation transform for PyTorch."""
|
|
2
2
|
|
|
3
3
|
import platform
|
|
4
|
-
from typing import Any, Optional
|
|
4
|
+
from typing import Any
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
@@ -49,8 +49,8 @@ class N2VManipulateTorch:
|
|
|
49
49
|
def __init__(
|
|
50
50
|
self,
|
|
51
51
|
n2v_manipulate_config: N2VManipulateModel,
|
|
52
|
-
seed: Optional[int] = None,
|
|
53
|
-
device: Optional[str] = None,
|
|
52
|
+
seed: int | None = None,
|
|
53
|
+
device: str | None = None,
|
|
54
54
|
):
|
|
55
55
|
"""Constructor.
|
|
56
56
|
|
|
@@ -69,7 +69,7 @@ class N2VManipulateTorch:
|
|
|
69
69
|
self.remove_center = n2v_manipulate_config.remove_center
|
|
70
70
|
|
|
71
71
|
if n2v_manipulate_config.struct_mask_axis == SupportedStructAxis.NONE:
|
|
72
|
-
self.struct_mask: Optional[StructMaskParameters] = None
|
|
72
|
+
self.struct_mask: StructMaskParameters | None = None
|
|
73
73
|
else:
|
|
74
74
|
self.struct_mask = StructMaskParameters(
|
|
75
75
|
axis=(
|
careamics/transforms/normalize.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Normalization and denormalization transforms for image patches."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
3
|
import numpy as np
|
|
6
4
|
from numpy.typing import NDArray
|
|
7
5
|
|
|
@@ -66,8 +64,8 @@ class Normalize(Transform):
|
|
|
66
64
|
self,
|
|
67
65
|
image_means: list[float],
|
|
68
66
|
image_stds: list[float],
|
|
69
|
-
target_means: Optional[list[float]] = None,
|
|
70
|
-
target_stds: Optional[list[float]] = None,
|
|
67
|
+
target_means: list[float] | None = None,
|
|
68
|
+
target_stds: list[float] | None = None,
|
|
71
69
|
):
|
|
72
70
|
"""Constructor.
|
|
73
71
|
|
|
@@ -92,9 +90,9 @@ class Normalize(Transform):
|
|
|
92
90
|
def __call__(
|
|
93
91
|
self,
|
|
94
92
|
patch: np.ndarray,
|
|
95
|
-
target: Optional[NDArray] = None,
|
|
93
|
+
target: NDArray | None = None,
|
|
96
94
|
**additional_arrays: NDArray,
|
|
97
|
-
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
95
|
+
) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
|
|
98
96
|
"""Apply the transform to the source patch and the target (optional).
|
|
99
97
|
|
|
100
98
|
Parameters
|
careamics/transforms/pixel_manipulation.py
CHANGED
|
@@ -5,8 +5,6 @@ Pixel manipulation is used in N2V and similar algorithm to replace the value of
|
|
|
5
5
|
masked pixels.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import Optional
|
|
9
|
-
|
|
10
8
|
import numpy as np
|
|
11
9
|
|
|
12
10
|
from .struct_mask_parameters import StructMaskParameters
|
|
@@ -16,7 +14,7 @@ def _apply_struct_mask(
|
|
|
16
14
|
patch: np.ndarray,
|
|
17
15
|
coords: np.ndarray,
|
|
18
16
|
struct_params: StructMaskParameters,
|
|
19
|
-
rng: Optional[np.random.Generator] = None,
|
|
17
|
+
rng: np.random.Generator | None = None,
|
|
20
18
|
) -> np.ndarray:
|
|
21
19
|
"""Apply structN2V masks to patch.
|
|
22
20
|
|
|
@@ -108,7 +106,7 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
|
108
106
|
def _get_stratified_coords(
|
|
109
107
|
mask_pixel_perc: float,
|
|
110
108
|
shape: tuple[int, ...],
|
|
111
|
-
rng: Optional[np.random.Generator] = None,
|
|
109
|
+
rng: np.random.Generator | None = None,
|
|
112
110
|
) -> np.ndarray:
|
|
113
111
|
"""
|
|
114
112
|
Generate coordinates of the pixels to mask.
|
|
@@ -241,8 +239,8 @@ def uniform_manipulate(
|
|
|
241
239
|
mask_pixel_percentage: float,
|
|
242
240
|
subpatch_size: int = 11,
|
|
243
241
|
remove_center: bool = True,
|
|
244
|
-
struct_params: Optional[StructMaskParameters] = None,
|
|
245
|
-
rng: Optional[np.random.Generator] = None,
|
|
242
|
+
struct_params: StructMaskParameters | None = None,
|
|
243
|
+
rng: np.random.Generator | None = None,
|
|
246
244
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
247
245
|
"""
|
|
248
246
|
Manipulate pixels by replacing them with a neighbor values.
|
|
@@ -321,8 +319,8 @@ def median_manipulate(
|
|
|
321
319
|
patch: np.ndarray,
|
|
322
320
|
mask_pixel_percentage: float,
|
|
323
321
|
subpatch_size: int = 11,
|
|
324
|
-
struct_params: Optional[StructMaskParameters] = None,
|
|
325
|
-
rng: Optional[np.random.Generator] = None,
|
|
322
|
+
struct_params: StructMaskParameters | None = None,
|
|
323
|
+
rng: np.random.Generator | None = None,
|
|
326
324
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
327
325
|
"""
|
|
328
326
|
Manipulate pixels by replacing them with the median of their surrounding subpatch.
|
careamics/transforms/pixel_manipulation_torch.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""N2V manipulation functions for PyTorch."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
3
|
import torch
|
|
6
4
|
|
|
7
5
|
from .struct_mask_parameters import StructMaskParameters
|
|
@@ -11,7 +9,7 @@ def _apply_struct_mask_torch(
|
|
|
11
9
|
patch: torch.Tensor,
|
|
12
10
|
coords: torch.Tensor,
|
|
13
11
|
struct_params: StructMaskParameters,
|
|
14
|
-
rng: Optional[torch.Generator] = None,
|
|
12
|
+
rng: torch.Generator | None = None,
|
|
15
13
|
) -> torch.Tensor:
|
|
16
14
|
"""Apply structN2V masks to patch.
|
|
17
15
|
|
|
@@ -154,8 +152,8 @@ def uniform_manipulate_torch(
|
|
|
154
152
|
mask_pixel_percentage: float,
|
|
155
153
|
subpatch_size: int = 11,
|
|
156
154
|
remove_center: bool = True,
|
|
157
|
-
struct_params: Optional[StructMaskParameters] = None,
|
|
158
|
-
rng: Optional[torch.Generator] = None,
|
|
155
|
+
struct_params: StructMaskParameters | None = None,
|
|
156
|
+
rng: torch.Generator | None = None,
|
|
159
157
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
160
158
|
"""
|
|
161
159
|
Manipulate pixels by replacing them with a neighbor values.
|
|
@@ -256,8 +254,8 @@ def median_manipulate_torch(
|
|
|
256
254
|
batch: torch.Tensor,
|
|
257
255
|
mask_pixel_percentage: float,
|
|
258
256
|
subpatch_size: int = 11,
|
|
259
|
-
struct_params: Optional[StructMaskParameters] = None,
|
|
260
|
-
rng: Optional[torch.Generator] = None,
|
|
257
|
+
struct_params: StructMaskParameters | None = None,
|
|
258
|
+
rng: torch.Generator | None = None,
|
|
261
259
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
262
260
|
"""
|
|
263
261
|
Manipulate pixels by replacing them with the median of their surrounding subpatch.
|
careamics/transforms/xy_flip.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""XY flip transform."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
3
|
import numpy as np
|
|
6
4
|
from numpy.typing import NDArray
|
|
7
5
|
|
|
@@ -43,7 +41,7 @@ class XYFlip(Transform):
|
|
|
43
41
|
flip_x: bool = True,
|
|
44
42
|
flip_y: bool = True,
|
|
45
43
|
p: float = 0.5,
|
|
46
|
-
seed: Optional[int] = None,
|
|
44
|
+
seed: int | None = None,
|
|
47
45
|
) -> None:
|
|
48
46
|
"""Constructor.
|
|
49
47
|
|
|
@@ -81,9 +79,9 @@ class XYFlip(Transform):
|
|
|
81
79
|
def __call__(
|
|
82
80
|
self,
|
|
83
81
|
patch: NDArray,
|
|
84
|
-
target: Optional[NDArray] = None,
|
|
82
|
+
target: NDArray | None = None,
|
|
85
83
|
**additional_arrays: NDArray,
|
|
86
|
-
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
84
|
+
) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
|
|
87
85
|
"""Apply the transform to the source patch and the target (optional).
|
|
88
86
|
|
|
89
87
|
Parameters
|
careamics/transforms/xy_random_rotate90.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Patch transform applying XY random 90 degrees rotations."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
3
|
import numpy as np
|
|
6
4
|
from numpy.typing import NDArray
|
|
7
5
|
|
|
@@ -30,7 +28,7 @@ class XYRandomRotate90(Transform):
|
|
|
30
28
|
Random seed, by default None.
|
|
31
29
|
"""
|
|
32
30
|
|
|
33
|
-
def __init__(self, p: float = 0.5, seed: Optional[int] = None):
|
|
31
|
+
def __init__(self, p: float = 0.5, seed: int | None = None):
|
|
34
32
|
"""Constructor.
|
|
35
33
|
|
|
36
34
|
Parameters
|
|
@@ -52,9 +50,9 @@ class XYRandomRotate90(Transform):
|
|
|
52
50
|
def __call__(
|
|
53
51
|
self,
|
|
54
52
|
patch: NDArray,
|
|
55
|
-
target: Optional[NDArray] = None,
|
|
53
|
+
target: NDArray | None = None,
|
|
56
54
|
**additional_arrays: NDArray,
|
|
57
|
-
) -> tuple[NDArray, Optional[NDArray], dict[str, NDArray]]:
|
|
55
|
+
) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
|
|
58
56
|
"""Apply the transform to the source patch and the target (optional).
|
|
59
57
|
|
|
60
58
|
Parameters
|
careamics/utils/logging.py
CHANGED
|
@@ -9,7 +9,7 @@ import sys
|
|
|
9
9
|
import time
|
|
10
10
|
from collections.abc import Generator
|
|
11
11
|
from pathlib import Path
|
|
12
|
-
from typing import Any, Optional, Union
|
|
12
|
+
from typing import Any, Union
|
|
13
13
|
|
|
14
14
|
LOGGERS: dict = {}
|
|
15
15
|
|
|
@@ -17,7 +17,7 @@ LOGGERS: dict = {}
|
|
|
17
17
|
def get_logger(
|
|
18
18
|
name: str,
|
|
19
19
|
log_level: int = logging.INFO,
|
|
20
|
-
log_path: Optional[Union[str, Path]] = None,
|
|
20
|
+
log_path: Union[str, Path] | None = None,
|
|
21
21
|
) -> logging.Logger:
|
|
22
22
|
"""
|
|
23
23
|
Create a python logger instance with configured handlers.
|
|
@@ -97,10 +97,10 @@ class ProgressBar:
|
|
|
97
97
|
|
|
98
98
|
def __init__(
|
|
99
99
|
self,
|
|
100
|
-
max_value: Optional[int] = None,
|
|
101
|
-
epoch: Optional[int] = None,
|
|
102
|
-
num_epochs: Optional[int] = None,
|
|
103
|
-
stateful_metrics: Optional[list] = None,
|
|
100
|
+
max_value: int | None = None,
|
|
101
|
+
epoch: int | None = None,
|
|
102
|
+
num_epochs: int | None = None,
|
|
103
|
+
stateful_metrics: list | None = None,
|
|
104
104
|
always_stateful: bool = False,
|
|
105
105
|
mode: str = "train",
|
|
106
106
|
) -> None:
|
|
@@ -159,7 +159,7 @@ class ProgressBar:
|
|
|
159
159
|
self.message = "Denoising"
|
|
160
160
|
|
|
161
161
|
def update(
|
|
162
|
-
self, current_step: int, batch_size: int = 1, values: Optional[list] = None
|
|
162
|
+
self, current_step: int, batch_size: int = 1, values: list | None = None
|
|
163
163
|
) -> None:
|
|
164
164
|
"""
|
|
165
165
|
Update the progress bar.
|
|
@@ -264,7 +264,7 @@ class ProgressBar:
|
|
|
264
264
|
|
|
265
265
|
self._last_update = now
|
|
266
266
|
|
|
267
|
-
def add(self, n: int, values: Optional[list] = None) -> None:
|
|
267
|
+
def add(self, n: int, values: list | None = None) -> None:
|
|
268
268
|
"""
|
|
269
269
|
Update the progress bar by n steps.
|
|
270
270
|
|
careamics/utils/metrics.py
CHANGED
|
@@ -5,7 +5,7 @@ This module contains various metrics and a metrics tracking class.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from collections.abc import Callable
|
|
8
|
-
from typing import Optional, Union
|
|
8
|
+
from typing import Union
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
import torch
|
|
@@ -210,7 +210,7 @@ class RunningPSNR:
|
|
|
210
210
|
self.mse_sum += torch.nansum(elementwise_mse)
|
|
211
211
|
self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse))
|
|
212
212
|
|
|
213
|
-
def get(self) -> Optional[torch.Tensor]:
|
|
213
|
+
def get(self) -> torch.Tensor | None:
|
|
214
214
|
"""Get the actual PSNR value given the running statistics.
|
|
215
215
|
|
|
216
216
|
Returns
|
careamics/utils/plotting.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Plotting utilities."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
3
|
import matplotlib.pyplot as plt
|
|
6
4
|
import numpy as np
|
|
7
5
|
import torch
|
|
@@ -14,7 +12,7 @@ def plot_noise_model_probability_distribution(
|
|
|
14
12
|
noise_model: GaussianMixtureNoiseModel,
|
|
15
13
|
signalBinIndex: int,
|
|
16
14
|
histogram: NDArray,
|
|
17
|
-
channel: Optional[str] = None,
|
|
15
|
+
channel: str | None = None,
|
|
18
16
|
number_of_bins: int = 100,
|
|
19
17
|
) -> None:
|
|
20
18
|
"""Plot probability distribution P(x|s) for a certain ground truth signal.
|
{careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: careamics
|
|
3
|
-
Version: 0.0.13
|
|
3
|
+
Version: 0.0.15
|
|
4
4
|
Summary: Toolbox for running N2V and friends.
|
|
5
5
|
Project-URL: homepage, https://careamics.github.io/
|
|
6
6
|
Project-URL: repository, https://github.com/CAREamics/careamics
|
|
@@ -15,7 +15,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.12
|
|
16
16
|
Classifier: Typing :: Typed
|
|
17
17
|
Requires-Python: >=3.10
|
|
18
|
-
Requires-Dist: bioimageio-core==0.
|
|
18
|
+
Requires-Dist: bioimageio-core==0.9.0
|
|
19
19
|
Requires-Dist: matplotlib<=3.10.3
|
|
20
20
|
Requires-Dist: numpy<2.0.0
|
|
21
21
|
Requires-Dist: pillow<=11.2.1
|
|
@@ -28,7 +28,6 @@ Requires-Dist: tifffile<=2025.5.10
|
|
|
28
28
|
Requires-Dist: torch<=2.7.1,>=2.0
|
|
29
29
|
Requires-Dist: torchvision<=0.22.1
|
|
30
30
|
Requires-Dist: typer<=0.16.0,>=0.12.3
|
|
31
|
-
Requires-Dist: xarray<2025.3.0
|
|
32
31
|
Requires-Dist: zarr<3.0.0
|
|
33
32
|
Provides-Extra: czi
|
|
34
33
|
Requires-Dist: pylibczirw<6.0.0,>=4.1.2; extra == 'czi'
|