careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Metrics submodule.
|
|
3
|
+
|
|
4
|
+
This module contains various metrics and a metrics tracking class.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
from skimage.metrics import peak_signal_noise_ratio
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float:
    """
    Compute the Peak Signal-to-Noise Ratio between two images.

    Thin wrapper around `skimage.metrics.peak_signal_noise_ratio`, see:
    https://scikit-image.org/docs/dev/api/skimage.metrics.html.

    Parameters
    ----------
    gt : NumPy array
        Ground truth image.
    pred : NumPy array
        Predicted image.
    range : float, optional
        Pixel range of the images, forwarded to scikit-image as `data_range`,
        by default 255.0. The name shadows the builtin `range` inside this
        function but is kept for backward compatibility with keyword callers.

    Returns
    -------
    float
        PSNR value.
    """
    # Bind to a non-shadowing local name before handing it to scikit-image.
    data_range = range
    psnr_value = peak_signal_noise_ratio(gt, pred, data_range=data_range)
    return psnr_value
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _zero_mean(x: np.ndarray) -> np.ndarray:
    """
    Subtract the global mean from an array.

    Parameters
    ----------
    x : NumPy array
        Input array.

    Returns
    -------
    NumPy array
        Copy of `x` whose mean (over all elements) is zero.
    """
    mean_value = np.mean(x)
    centred = x - mean_value
    return centred
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _fix_range(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
    """
    Rescale an array so that it best matches a reference array.

    The scale factor is ``sum(gt * x) / sum(x * x)``, i.e. the least-squares
    scalar that maps `x` onto `gt`. No offset is applied.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth (reference) image.
    x : np.ndarray
        Input array to rescale.

    Returns
    -------
    np.ndarray
        `x` multiplied by the scale factor.
    """
    numerator = np.sum(gt * x)
    denominator = np.sum(x * x)
    scale = numerator / denominator
    return x * scale
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _fix(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
    """
    Zero-mean both arrays, then rescale `x` to best match the ground truth.

    Both `gt` and `x` have their global mean removed; the centred `x` is then
    multiplied by the least-squares scale factor that maps it onto the centred
    `gt` (see `_fix_range`).

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    x : np.ndarray
        Input array.

    Returns
    -------
    np.ndarray
        Zero-mean and range-adjusted version of `x`.
    """
    centred_gt = _zero_mean(gt)
    centred_x = _zero_mean(x)
    return _fix_range(centred_gt, centred_x)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def scale_invariant_psnr(
    gt: np.ndarray, pred: np.ndarray
) -> Union[float, torch.Tensor]:
    """
    Scale invariant PSNR.

    The ground truth is standardized (zero mean, unit standard deviation) and
    the prediction is zero-meaned and rescaled by the least-squares factor that
    best maps it onto the standardized ground truth. The PSNR is then computed
    with a data range equal to the ground-truth range expressed in units of its
    standard deviation.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    pred : np.ndarray
        Predicted image.

    Returns
    -------
    Union[float, torch.Tensor]
        Scale invariant PSNR value.
    """
    # Compute the standard deviation once; it is used both for the range and
    # for standardizing the ground truth.
    gt_std = np.std(gt)

    # Ground-truth range in units of its standard deviation, matching the
    # standardized ground truth below.
    range_parameter = (np.max(gt) - np.min(gt)) / gt_std
    gt_ = _zero_mean(gt) / gt_std
    return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class RunningPSNR:
    """Compute the running PSNR during validation step in training.

    This class allows to compute the PSNR on the entire validation set
    one batch at the time.

    The PSNR is computed as ``20 * log10((max - min) / rmse)``, where ``max``
    and ``min`` are the extreme target values seen since the last reset and
    ``rmse`` is the square root of the average per-sample MSE. Samples whose
    MSE is NaN are excluded from both the MSE sum and the sample count.

    Attributes
    ----------
    N : int or torch.Tensor
        Number of (non-NaN) samples seen so far during the epoch. It is an int
        after `reset` and becomes a tensor after the first `update`.
    mse_sum : float or torch.Tensor
        Running sum of the per-sample MSE over the N samples seen so far.
    max : float or None
        Running max value of the target images seen so far.
    min : float or None
        Running min value of the target images seen so far.
    """

    def __init__(self):
        """Constructor."""
        self.N = None
        self.mse_sum = None
        self.max = self.min = None
        self.reset()

    def reset(self):
        """Reset the running PSNR computation.

        Usually called at the end of each epoch.
        """
        self.mse_sum = 0
        self.N = 0
        self.max = self.min = None

    def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None:
        """Update the running PSNR statistics given a new batch.

        Parameters
        ----------
        rec : torch.Tensor
            Reconstructed batch; the first dimension is treated as the batch
            dimension.
        tar : torch.Tensor
            Target batch, same shape as `rec`.
        """
        ins_max = torch.max(tar).item()
        ins_min = torch.min(tar).item()
        if self.max is None:
            assert self.min is None
            self.max = ins_max
            self.min = ins_min
        else:
            self.max = max(self.max, ins_max)
            self.min = min(self.min, ins_min)

        mse = (rec - tar) ** 2
        # `reshape` (unlike `view`) also works when `mse` is non-contiguous,
        # e.g. when the inputs are permuted or sliced views.
        elementwise_mse = torch.mean(mse.reshape(len(mse), -1), dim=1)
        # Samples with a NaN MSE are skipped in both the sum and the count.
        self.mse_sum += torch.nansum(elementwise_mse)
        self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse))

    def get(self) -> Optional[torch.Tensor]:
        """Get the actual PSNR value given the running statistics.

        Returns
        -------
        Optional[torch.Tensor]
            PSNR value, or None if no (non-NaN) sample has been seen since the
            last reset.
        """
        if self.N is None or self.N == 0:
            return None
        rmse = torch.sqrt(self.mse_sum / self.N)
        return 20 * torch.log10((self.max - self.min) / rmse)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Utility functions for paths."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def check_path_exists(path: Union[str, Path]) -> Path:
    """Ensure a path exists and return it as a `Path` object.

    Parameters
    ----------
    path : Union[str, Path]
        Path to check.

    Returns
    -------
    Path
        The input converted to a `Path` object.

    Raises
    ------
    FileNotFoundError
        If nothing exists at `path`.
    """
    resolved = Path(path)
    if resolved.exists():
        return resolved
    raise FileNotFoundError(f"Data path {resolved} is incorrect or does not exist.")
|
careamics/utils/ram.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Receptive field calculation for computing the tile overlap."""
|
|
2
|
+
|
|
3
|
+
# TODO better docstring and function names
|
|
4
|
+
# Adapted from: https://github.com/frgfm/torch-scan
|
|
5
|
+
|
|
6
|
+
# import math
|
|
7
|
+
# import warnings
|
|
8
|
+
# from typing import Tuple, Union
|
|
9
|
+
|
|
10
|
+
# from torch import Tensor, nn
|
|
11
|
+
# from torch.nn import Module
|
|
12
|
+
# from torch.nn.modules.batchnorm import _BatchNorm
|
|
13
|
+
# from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
|
|
14
|
+
# from torch.nn.modules.pooling import (
|
|
15
|
+
# _AdaptiveAvgPoolNd,
|
|
16
|
+
# _AdaptiveMaxPoolNd,
|
|
17
|
+
# _AvgPoolNd,
|
|
18
|
+
# _MaxPoolNd,
|
|
19
|
+
# )
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]:
|
|
23
|
+
# """Estimate the spatial receptive field of the module.
|
|
24
|
+
|
|
25
|
+
# Parameters
|
|
26
|
+
# ----------
|
|
27
|
+
# module : Module
|
|
28
|
+
# Module to estimate the receptive field.
|
|
29
|
+
# inp : Tensor
|
|
30
|
+
# Input tensor.
|
|
31
|
+
# out : Tensor
|
|
32
|
+
# Output tensor.
|
|
33
|
+
|
|
34
|
+
# Returns
|
|
35
|
+
# -------
|
|
36
|
+
# Tuple[float, float, float]
|
|
37
|
+
# Receptive field, effective stride and padding.
|
|
38
|
+
# """
|
|
39
|
+
# if isinstance(
|
|
40
|
+
# module,
|
|
41
|
+
# (
|
|
42
|
+
# nn.Identity,
|
|
43
|
+
# nn.Flatten,
|
|
44
|
+
# nn.ReLU,
|
|
45
|
+
# nn.ELU,
|
|
46
|
+
# nn.LeakyReLU,
|
|
47
|
+
# nn.ReLU6,
|
|
48
|
+
# nn.Tanh,
|
|
49
|
+
# nn.Sigmoid,
|
|
50
|
+
# _BatchNorm,
|
|
51
|
+
# nn.Dropout,
|
|
52
|
+
# nn.Linear,
|
|
53
|
+
# ),
|
|
54
|
+
# ):
|
|
55
|
+
# return 1.0, 1.0, 0.0
|
|
56
|
+
# elif isinstance(module, _ConvTransposeNd):
|
|
57
|
+
# return rf_convtransposend(module, inp, out)
|
|
58
|
+
# elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)):
|
|
59
|
+
# return rf_aggregnd(module, inp, out)
|
|
60
|
+
# elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
|
|
61
|
+
# return rf_adaptive_poolnd(module, inp, out)
|
|
62
|
+
# else:
|
|
63
|
+
# warnings.warn(
|
|
64
|
+
# f"Module type not supported: {module.__class__.__name__}", stacklevel=1
|
|
65
|
+
# )
|
|
66
|
+
# return 1.0, 1.0, 0.0
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# def rf_convtransposend(
|
|
70
|
+
# module: _ConvTransposeNd, _: Tensor, __: Tensor
|
|
71
|
+
# ) -> Tuple[float, float, float]:
|
|
72
|
+
# k = (
|
|
73
|
+
# module.kernel_size[0]
|
|
74
|
+
# if isinstance(module.kernel_size, tuple)
|
|
75
|
+
# else module.kernel_size
|
|
76
|
+
# )
|
|
77
|
+
# s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
|
|
78
|
+
# return -k, 1.0 / s, 0.0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# def rf_aggregnd(
|
|
82
|
+
# module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor
|
|
83
|
+
# ) -> Tuple[float, float, float]:
|
|
84
|
+
# k = (
|
|
85
|
+
# module.kernel_size[0]
|
|
86
|
+
# if isinstance(module.kernel_size, tuple)
|
|
87
|
+
# else module.kernel_size
|
|
88
|
+
# )
|
|
89
|
+
# if hasattr(module, "dilation"):
|
|
90
|
+
# d = (
|
|
91
|
+
# module.dilation[0]
|
|
92
|
+
# if isinstance(module.dilation, tuple)
|
|
93
|
+
# else module.dilation
|
|
94
|
+
# )
|
|
95
|
+
# k = d * (k - 1) + 1
|
|
96
|
+
# s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
|
|
97
|
+
# p = module.padding[0] if isinstance(module.padding, tuple) else module.padding
|
|
98
|
+
# return k, s, p # type: ignore[return-value]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# def rf_adaptive_poolnd(
|
|
102
|
+
# _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor
|
|
103
|
+
# ) -> Tuple[int, int, float]:
|
|
104
|
+
# stride = math.ceil(inp.shape[-1] / out.shape[-1])
|
|
105
|
+
# kernel_size = stride
|
|
106
|
+
# padding = (inp.shape[-1] - kernel_size * stride) / 2
|
|
107
|
+
|
|
108
|
+
# return kernel_size, stride, padding
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Convenience functions using torch.
|
|
3
|
+
|
|
4
|
+
These functions are used to control certain aspects and behaviours of PyTorch.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import inspect
|
|
8
|
+
from typing import Dict, Type, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from careamics.config.support import SupportedOptimizer, SupportedScheduler
|
|
13
|
+
|
|
14
|
+
from ..utils.logging import get_logger
|
|
15
|
+
|
|
16
|
+
# Module-level logger obtained from the project's `get_logger` helper; none of
# the functions in this module currently use it.
logger = get_logger(__name__) # TODO are loggers still needed?
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def filter_parameters(
    func: type,
    user_params: dict,
) -> dict:
    """
    Filter parameters according to the function signature.

    Only the entries of `user_params` whose key is a parameter name in the
    signature of `func` are kept. The returned dictionary preserves the key
    order of `user_params`.

    Parameters
    ----------
    func : type
        Class object (or other callable) whose signature is inspected.
    user_params : Dict
        User provided parameters.

    Returns
    -------
    Dict
        Parameters matching `func`'s signature.
    """
    # Names of all parameters accepted by `func`.
    accepted = set(inspect.signature(func).parameters)

    # Iterate over `user_params` rather than over a set intersection so that
    # the result has a deterministic key order (the caller's order).
    return {key: value for key, value in user_params.items() if key in accepted}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_optimizer(name: str) -> Type[torch.optim.Optimizer]:
    """
    Return the optimizer class given its name.

    Parameters
    ----------
    name : str
        Optimizer name. It must be in `SupportedOptimizer` and is looked up as
        an attribute of `torch.optim`.

    Returns
    -------
    Type[torch.optim.Optimizer]
        Optimizer class (not an instance).

    Raises
    ------
    NotImplementedError
        If `name` is not in `SupportedOptimizer`.
    """
    if name not in SupportedOptimizer:
        raise NotImplementedError(f"Optimizer {name} is not yet supported.")

    return getattr(torch.optim, name)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_optimizers() -> Dict[str, str]:
    """
    List the optimizer classes exposed by `torch.optim`.

    Every class in `torch.optim` that subclasses `torch.optim.Optimizer` is
    included, except the `Optimizer` base class itself.

    Returns
    -------
    Dict
        Mapping of each optimizer name to itself.
    """
    base_class = torch.optim.Optimizer
    return {
        member_name: member_name
        for member_name, member in inspect.getmembers(torch.optim)
        if inspect.isclass(member)
        and issubclass(member, base_class)
        and member_name != "Optimizer"
    }
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def get_scheduler(
    name: str,
) -> Union[
    Type[torch.optim.lr_scheduler.LRScheduler],
    Type[torch.optim.lr_scheduler.ReduceLROnPlateau],
]:
    """
    Return the scheduler class given its name.

    Parameters
    ----------
    name : str
        Scheduler name. It must be in `SupportedScheduler` and is looked up as
        an attribute of `torch.optim.lr_scheduler`.

    Returns
    -------
    Union[Type[LRScheduler], Type[ReduceLROnPlateau]]
        Scheduler class (not an instance).

    Raises
    ------
    NotImplementedError
        If `name` is not in `SupportedScheduler`.
    """
    if name not in SupportedScheduler:
        raise NotImplementedError(f"Scheduler {name} is not yet supported.")

    return getattr(torch.optim.lr_scheduler, name)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_schedulers() -> Dict[str, str]:
    """
    List the learning-rate scheduler classes in `torch.optim.lr_scheduler`.

    A member is included when it is a subclass of `LRScheduler` whose name does
    not contain "LRScheduler" (which drops the base classes). `ReduceLROnPlateau`
    is always included, because it may not derive from `LRScheduler` depending
    on the torch version.

    Returns
    -------
    Dict
        Mapping of each scheduler name to itself.
    """
    base_class = torch.optim.lr_scheduler.LRScheduler

    def _is_listed(member_name, member) -> bool:
        """Decide whether a module member belongs in the returned mapping."""
        if inspect.isclass(member) and issubclass(member, base_class):
            return "LRScheduler" not in member_name
        return member_name == "ReduceLROnPlateau"

    return {
        member_name: member_name
        for member_name, member in inspect.getmembers(torch.optim.lr_scheduler)
        if _is_listed(member_name, member)
    }
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: careamics
|
|
3
|
+
Version: 0.0.3
|
|
4
|
+
Summary: Toolbox for running N2V and friends.
|
|
5
|
+
Project-URL: homepage, https://careamics.github.io/
|
|
6
|
+
Project-URL: repository, https://github.com/CAREamics/careamics
|
|
7
|
+
Author-email: CAREamics team <rse@fht.org>, Ashesh <ashesh.ashesh@fht.org>, Federico Carrara <federico.carrara@fht.org>, Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Vera Galinova <vera.galinova@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
|
|
8
|
+
License: BSD-3-Clause
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Typing :: Typed
|
|
18
|
+
Requires-Python: >=3.9
|
|
19
|
+
Requires-Dist: bioimageio-core>=0.6.0
|
|
20
|
+
Requires-Dist: numpy<2.0.0
|
|
21
|
+
Requires-Dist: psutil
|
|
22
|
+
Requires-Dist: pydantic>=2.5
|
|
23
|
+
Requires-Dist: pytorch-lightning>=2.2.0
|
|
24
|
+
Requires-Dist: pyyaml
|
|
25
|
+
Requires-Dist: scikit-image<=0.23.2
|
|
26
|
+
Requires-Dist: tifffile
|
|
27
|
+
Requires-Dist: torch>=2.0.0
|
|
28
|
+
Requires-Dist: torchvision
|
|
29
|
+
Requires-Dist: zarr<3.0.0
|
|
30
|
+
Provides-Extra: dev
|
|
31
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
|
32
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
33
|
+
Requires-Dist: pytest-cov; extra == 'dev'
|
|
34
|
+
Requires-Dist: sybil; extra == 'dev'
|
|
35
|
+
Provides-Extra: examples
|
|
36
|
+
Requires-Dist: careamics-portfolio; extra == 'examples'
|
|
37
|
+
Requires-Dist: jupyter; extra == 'examples'
|
|
38
|
+
Requires-Dist: matplotlib; extra == 'examples'
|
|
39
|
+
Provides-Extra: tensorboard
|
|
40
|
+
Requires-Dist: protobuf==3.20.3; extra == 'tensorboard'
|
|
41
|
+
Requires-Dist: tensorboard; extra == 'tensorboard'
|
|
42
|
+
Provides-Extra: wandb
|
|
43
|
+
Requires-Dist: wandb; extra == 'wandb'
|
|
44
|
+
Description-Content-Type: text/markdown
|
|
45
|
+
|
|
46
|
+
<p align="center">
|
|
47
|
+
<a href="https://careamics.github.io/">
|
|
48
|
+
<img src="https://raw.githubusercontent.com/CAREamics/.github/main/profile/images/banner_careamics.png">
|
|
49
|
+
</a>
|
|
50
|
+
</p>
|
|
51
|
+
|
|
52
|
+
# CAREamics
|
|
53
|
+
|
|
54
|
+
[](https://github.com/CAREamics/careamics/blob/main/LICENSE)
|
|
55
|
+
[](https://pypi.org/project/careamics)
|
|
56
|
+
[](https://python.org)
|
|
57
|
+
[](https://github.com/CAREamics/careamics/actions/workflows/ci.yml)
|
|
58
|
+
[](https://codecov.io/gh/CAREamics/careamics)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
CAREamics is a PyTorch library aimed at simplifying the use of Noise2Void and its many
|
|
62
|
+
variants and cousins (CARE, Noise2Noise, N2V2, P(P)N2V, HDN, muSplit etc.).
|
|
63
|
+
|
|
64
|
+
## Why CAREamics?
|
|
65
|
+
|
|
66
|
+
Noise2Void is a widely used denoising algorithm, and is readily available from the `n2v`
|
|
67
|
+
python package. However, `n2v` is based on TensorFlow, while more recent
|
|
68
|
+
denoising methods (PPN2V, DivNoising, HDN) are all implemented in PyTorch, but are
|
|
69
|
+
lacking the extra features that would make them usable by the community.
|
|
70
|
+
|
|
71
|
+
The aim of CAREamics is to provide a PyTorch library bringing together all the latest methods
|
|
72
|
+
in one package, while providing a simple and consistent API. The library relies on
|
|
73
|
+
PyTorch Lightning as a back-end. In addition, we will provide extensive documentation and
|
|
74
|
+
tutorials on how to best apply these methods in a scientific context.
|
|
75
|
+
|
|
76
|
+
## Installation and use
|
|
77
|
+
|
|
78
|
+
Check out the [documentation](https://careamics.github.io/) for installation instructions and guides!
|