careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,541 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
11
|
+
|
|
12
|
+
# TODO this module shouldn't be in lvae folder
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def noise_model_factory(
    model_config: Optional[MultiChannelNMConfig],
) -> Optional[MultiChannelNoiseModel]:
    """Build a multi-channel noise model from its configuration.

    Each entry of `model_config.noise_models` yields one per-channel noise model.
    It is loaded from disk when the entry has a `path`, and trained from scratch
    otherwise.

    Parameters
    ----------
    model_config : Optional[MultiChannelNMConfig]
        Noise model configuration, a `MultiChannelNMConfig` config that defines
        noise models for the different output channels. A falsy value (e.g.
        `None`) means "no noise model".

    Returns
    -------
    Optional[MultiChannelNoiseModel]
        A noise model wrapping one model per channel, or `None` when no
        configuration is given.

    Raises
    ------
    NotImplementedError
        If the chosen noise model `model_type` is not implemented.
        Currently only `GaussianMixtureNoiseModel` is implemented.
    """
    if not model_config:
        return None

    channel_models = []
    for nm_config in model_config.noise_models:
        # Only one noise model family exists so far, for both loading and training.
        if nm_config.model_type != "GaussianMixtureNoiseModel":
            raise NotImplementedError(
                f"Model {nm_config.model_type} is not implemented"
            )

        if nm_config.path:
            # A stored model exists: load it.
            channel_models.append(GaussianMixtureNoiseModel(nm_config))
        else:
            # No stored model: fit a new one. Presumably the pydantic config
            # guarantees signal/observation data in this case.
            # TODO train a new model. Config should always be provided?
            channel_models.append(train_gm_noise_model(nm_config))

    return MultiChannelNoiseModel(channel_models)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def train_gm_noise_model(
    model_config: GaussianMixtureNMConfig,
    signal: Optional[np.ndarray] = None,
    observation: Optional[np.ndarray] = None,
) -> GaussianMixtureNoiseModel:
    """Train a Gaussian mixture noise model from signal/observation pairs.

    Parameters
    ----------
    model_config : GaussianMixtureNMConfig
        Configuration used to initialize the (untrained) noise model.
    signal : numpy.ndarray, optional
        Clean signal images. If None, `model_config.signal` is used.
    observation : numpy.ndarray, optional
        Noisy observation images. If None, `model_config.observation` is used.

    Returns
    -------
    GaussianMixtureNoiseModel
        The noise model, trained in place by `train_noise_model`.

    Raises
    ------
    ValueError
        If no signal or no observation data is available.
    """
    # TODO where to put train params?
    # TODO any training params ? Different channels ?
    # The data used to be read from `noise_model.signal` / `noise_model.observation`,
    # attributes that `GaussianMixtureNoiseModel` never sets (always AttributeError).
    # NOTE(review): assumes `GaussianMixtureNMConfig` exposes `signal` and
    # `observation` fields holding arrays when `path` is None -- confirm.
    if signal is None:
        signal = getattr(model_config, "signal", None)
    if observation is None:
        observation = getattr(model_config, "observation", None)
    if signal is None or observation is None:
        raise ValueError(
            "Training a GaussianMixtureNoiseModel requires both `signal` and "
            "`observation` data, either as arguments or in the configuration."
        )

    noise_model = GaussianMixtureNoiseModel(model_config)
    noise_model.train_noise_model(signal, observation)
    return noise_model
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class MultiChannelNoiseModel(nn.Module):
    """Container of per-channel noise models (e.g., for muSplit, denoiSplit)."""

    def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):
        """Constructor.

        Registers each non-None noise model as a submodule named
        ``nmodel_<index>``, where ``<index>`` is its position in `nmodels`.
        Likelihoods are computed channel by channel and concatenated.

        Parameters
        ----------
        nmodels : list[GaussianMixtureNoiseModel]
            List of noise models, one for each output channel. `None`
            entries are skipped and not counted.
        """
        super().__init__()
        # TODO: wouldn't be easier to use a list?
        registered = 0
        for position, channel_model in enumerate(nmodels):
            if channel_model is None:
                continue
            self.add_module(f"nmodel_{position}", channel_model)
            registered += 1
        self._nm_cnt = registered

        print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")

    def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
        """Compute the likelihood of observations given signals for each channel.

        Parameters
        ----------
        obs : torch.Tensor
            Noisy observations, i.e., the target(s): the input noisy image for
            HDN, or the noisy unmixed images used for supervision in
            denoiSplit. Shape: (B, C, [Z], Y, X).
        signal : torch.Tensor
            Underlying signals, i.e., the (clean) output of the model: the
            denoised image for HDN, or the unmixed images for denoiSplit.
            Shape: (B, C, [Z], Y, X).

        Returns
        -------
        torch.Tensor
            Per-channel likelihoods concatenated along dimension 1.
        """
        n_channels = obs.shape[1]

        # Single channel (e.g., plain denoising): delegate to the first model.
        if n_channels == 1:
            assert signal.shape[1] == 1
            return self.nmodel_0.likelihood(obs, signal)

        # Several channels (e.g., denoiSplit): one noise model per channel.
        assert n_channels == self._nm_cnt, (
            "The number of channels in `obs` must match the number of noise models."
            f" Got instead: obs={obs.shape[1]}, nm={self._nm_cnt}"
        )
        per_channel = [
            # `c : c + 1` slicing keeps the channel dimension.
            getattr(self, f"nmodel_{c}").likelihood(
                obs[:, c : c + 1], signal[:, c : c + 1]
            )
            for c in range(n_channels)
        ]
        return torch.cat(per_channel, dim=1)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# TODO: is this needed?
|
|
151
|
+
def fastShuffle(series, num):
    """Shuffle the rows of an array by applying `num` successive random permutations.

    Uses NumPy's global random state (`np.random.permutation`).

    Parameters
    ----------
    series : numpy.ndarray
        Array whose first axis is permuted. It is indexed as ``series[rows, :]``,
        so it must have at least two dimensions.
    num : int
        Number of successive row permutations to apply.

    Returns
    -------
    numpy.ndarray
        Row-shuffled copy of `series`, or `series` itself when `num` <= 0.
    """
    n_rows = series.shape[0]
    shuffled = series
    for _ in range(num):
        row_order = np.random.permutation(n_rows)
        shuffled = shuffled[row_order, :]
    return shuffled
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class GaussianMixtureNoiseModel(nn.Module):
    """Define a noise model parameterized as a mixture of gaussians.

    If `config.path` is not provided a new object is initialized from scratch.
    Otherwise, a model is loaded from `config.path`, an ``.npz`` archive holding
    the arrays ``trained_weight``, ``min_signal``, ``max_signal`` and
    ``min_sigma``.

    Parameters
    ----------
    config : GaussianMixtureNMConfig
        A `pydantic` model that defines the configuration of the GMM noise model.

    Attributes
    ----------
    min_signal : torch.Tensor
        One-element tensor, minimum signal intensity expected in the image.
    max_signal : torch.Tensor
        One-element tensor, maximum signal intensity expected in the image.
    weight : torch.nn.Parameter
        A [3*n_gaussian, n_coeff] sized array of polynomial coefficients
        describing the GMM noise model. Rows are organized as follows
        (see `getGaussianParameters`):
        - first n_gaussian rows: coefficients of the means;
        - next n_gaussian rows: logs of the coefficients of the quantity whose
          square root is the standard deviation;
        - last n_gaussian rows: coefficients of the log of the (unnormalized)
          mixture weights.
        If `config.weight` is None, the array is randomly initialized using the
        `min_signal` and `max_signal` parameters.
    n_gaussian : int
        Number of gaussians in the mixture.
    n_coeff : int
        Number of coefficients to describe the functional relationship between
        gaussian parameters and the signal. 2 implies a linear relationship,
        3 implies a quadratic relationship and so on.
    min_sigma : float
        Lower clamp applied before the square root that yields the standard
        deviation, i.e. standard deviations are at least ``sqrt(min_sigma)``.
    tol : torch.Tensor
        Small constant (1e-10) added to likelihoods and mixture weights for
        numerical stability.
    """

    # TODO training a NM relies on getting a clean data(N2V e.g,)
    def __init__(self, config: GaussianMixtureNMConfig):
        super().__init__()
        self._learnable = False

        if config.path is None:
            # Fresh model, usually trained afterwards with `train_noise_model`.
            weight = config.weight
            n_gaussian = config.n_gaussian
            n_coeff = config.n_coeff
            min_signal = config.min_signal
            max_signal = config.max_signal
            # TODO min_sigma cant be None ?
            self.min_sigma = config.min_sigma
            if weight is None:
                weight = np.random.randn(n_gaussian * 3, n_coeff)
                # Initial linear term of the std rows: log of the signal range.
                weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
            # Registered as a trainable Parameter on the CPU. `module.to(...)`,
            # `parameters()` and `train_noise_model` then manage its device.
            # Previously this was a plain tensor with a hard-coded `.cuda()`,
            # which crashed on machines without a GPU. The clone avoids
            # mutating an array/tensor supplied through the config.
            self.weight = nn.Parameter(
                torch.as_tensor(weight, dtype=torch.float32).detach().clone(),
                requires_grad=True,
            )
            self.n_gaussian = self.weight.shape[0] // 3
            self.n_coeff = self.weight.shape[1]
            self.min_signal = torch.Tensor([min_signal])
            self.max_signal = torch.Tensor([max_signal])
            self.tol = torch.Tensor([1e-10])
        else:
            # `np.load` on an .npz returns a lazily-read NpzFile holding an open
            # file handle: read everything inside the context so it gets closed.
            with np.load(config.path) as params:
                self.min_signal = torch.tensor(
                    np.asarray(params["min_signal"], dtype=np.float32).reshape(1)
                )
                self.max_signal = torch.tensor(
                    np.asarray(params["max_signal"], dtype=np.float32).reshape(1)
                )
                # Frozen by default; `make_learnable` re-enables gradients.
                self.weight = nn.Parameter(
                    torch.tensor(params["trained_weight"], dtype=torch.float32),
                    requires_grad=False,
                )
                self.min_sigma = params["min_sigma"].item()
            self.n_gaussian = self.weight.shape[0] // 3
            self.n_coeff = self.weight.shape[1]
            self.tol = torch.Tensor([1e-10])

        print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")

    def make_learnable(self):
        """Enable gradients on `weight` so the noise model can be optimized."""
        print(f"[{self.__class__.__name__}] Making noise model learnable")
        self._learnable = True
        self.weight.requires_grad = True

    def to_device(self, cuda_tensor):
        """Move the auxiliary tensors to the device of `cuda_tensor`.

        `min_signal`, `max_signal` and `tol` are plain tensor attributes, not
        parameters or buffers, so `nn.Module.to` does not move them. `weight`
        is a Parameter and follows `nn.Module.to`.

        Parameters
        ----------
        cuda_tensor : torch.Tensor
            Reference tensor whose device is the target (any device, not only CUDA).
        """
        device = cuda_tensor.device
        if self.min_signal.device != device:
            # Previously hard-coded `.cuda()`, which ignored the target device.
            self.max_signal = self.max_signal.to(device)
            self.min_signal = self.min_signal.to(device)
            self.tol = self.tol.to(device)
            if self._learnable:
                self.weight.requires_grad = True

    def polynomialRegressor(self, weightParams, signals):
        """Combine `weightParams` and `signals` to regress gaussian parameter values.

        Evaluates ``sum_i weightParams[i] * s_norm ** i``, where ``s_norm`` is
        `signals` rescaled by ``(signals - min_signal) / (max_signal - min_signal)``.

        Parameters
        ----------
        weightParams : torch.Tensor
            One row (or transformed row) of `self.weight`, i.e. `n_coeff`
            polynomial coefficients.
        signals : torch.Tensor
            Signals.

        Returns
        -------
        value : torch.Tensor
            Either a mean, standard-deviation-related or weight value, evaluated
            at `signals`.
        """
        normalized = (signals - self.min_signal) / (self.max_signal - self.min_signal)
        value = 0
        for i in range(weightParams.shape[0]):
            value = value + weightParams[i] * normalized**i
        return value

    def normalDens(
        self, x: torch.Tensor, m_: torch.Tensor = 0.0, std_: torch.Tensor = None
    ) -> torch.Tensor:
        """Evaluate the normal probability density at `x` given the mean `m_` and
        standard deviation `std_`.

        Parameters
        ----------
        x: torch.Tensor
            Observations (i.e., noisy image).
        m_: torch.Tensor
            Pixel-wise mean.
        std_: torch.Tensor
            Pixel-wise standard deviation. Must be provided despite the `None`
            default.

        Returns
        -------
        tmp: torch.Tensor
            Normal probability density of `x` given `m_` and `std_`.
        """
        tmp = -((x - m_) ** 2)
        tmp = tmp / (2.0 * std_ * std_)
        tmp = torch.exp(tmp)
        tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
        return tmp

    def likelihood(
        self, observations: torch.Tensor, signals: torch.Tensor
    ) -> torch.Tensor:
        """Evaluate the likelihood of observations given the signals and the
        corresponding gaussian parameters.

        Parameters
        ----------
        observations : torch.Tensor
            Noisy observations.
        signals : torch.Tensor
            Underlying signals.

        Returns
        -------
        torch.Tensor
            Element-wise mixture likelihood of `observations` given `signals`,
            plus `self.tol`.
        """
        self.to_device(signals)  # move auxiliary tensors to `signals`' device
        gaussianParameters = self.getGaussianParameters(signals)
        p = 0
        for gaussian in range(self.n_gaussian):
            p += (
                self.normalDens(
                    x=observations,
                    m_=gaussianParameters[gaussian],
                    std_=gaussianParameters[self.n_gaussian + gaussian],
                )
                * gaussianParameters[2 * self.n_gaussian + gaussian]
            )
        return p + self.tol

    def getGaussianParameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
        """Return the gaussian mixture parameters for the given signals.

        Parameters
        ----------
        signals : torch.Tensor
            Underlying signals.

        Returns
        -------
        gmmParams: list[torch.Tensor]
            ``3 * n_gaussian`` tensors: first the means `mu`, then the standard
            deviations `sigma`, then the normalized mixture weights `alpha`.
        """
        kernels = self.weight.shape[0] // 3
        mu = []
        sigma = []
        alpha = []
        for num in range(kernels):
            # Mean of gaussian `num`: polynomial in the normalized signal.
            mu.append(self.polynomialRegressor(self.weight[num, :], signals))

            # The stored coefficients are exponentiated, the regressed value is
            # clamped to `min_sigma` and its square root is the std.
            # TODO: why taking the exp? it is not in PPN2V paper...
            expval = torch.exp(self.weight[kernels + num, :])
            sigmaTemp = self.polynomialRegressor(expval, signals)
            sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
            sigma.append(torch.sqrt(sigmaTemp))

            # NOTE: these are the numerators of the mixture weights.
            expval = torch.exp(
                self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
                + self.tol
            )
            alpha.append(expval)

        # Normalize so that the mixture weights sum to 1.
        sum_alpha = sum(alpha)
        alpha = [a / sum_alpha for a in alpha]

        # Subtracting the alpha-weighted average of the means and adding the
        # signal makes the mixture mean equal to `signals`.
        # TODO: I don't understand why we need to learn the mean?
        sum_means = sum(a * m for a, m in zip(alpha, mu))
        mu = [m - sum_means + signals for m in mu]

        return mu + sigma + alpha

    # TODO: this is to train the noise model
    def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
        """Return the signal-observation pixel intensities as a two-column array.

        Parameters
        ----------
        signal : numpy.ndarray
            Clean signal data, one image per entry of the first axis.
        observation : numpy.ndarray
            Noisy observation data, one image per entry of the first axis. Each
            signal image is paired with ``len(observation) // len(signal)``
            consecutive observations (presumably the number of observations is
            a multiple of the number of signals -- confirm with callers).
        lowerClip : float
            Lower percentile bound for clipping.
        upperClip : float
            Upper percentile bound for clipping.

        Returns
        -------
        numpy.ndarray
            Shuffled array of shape (n_pairs, 2). Column 0 holds signal and
            column 1 observation intensities. Only pairs whose signal lies
            strictly between the `lowerClip` and `upperClip` percentiles are kept.
        """
        lb = np.percentile(signal, lowerClip)
        ub = np.percentile(signal, upperClip)
        stepsize = observation[0].size
        n_observations = observation.shape[0]
        n_signals = signal.shape[0]
        sig_obs_pairs = np.zeros((n_observations * stepsize, 2))

        for i in range(n_observations):
            j = i // (n_observations // n_signals)
            sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel()
            sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel()
        sig_obs_pairs = sig_obs_pairs[
            (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
        ]
        return fastShuffle(sig_obs_pairs, 2)

    # TODO: what's the use of this method?
    def forward(self, x, y):
        """Temporary dummy forward method."""
        return x, y

    # TODO taken from pn2v. Ashesh needs to clarify this
    def train_noise_model(
        self,
        signal,
        observation,
        learning_rate=1e-1,
        batchSize=250000,
        n_epochs=2000,
        name="GMMNoiseModel.npz",
        lowerClip=0,
        upperClip=100,
        device=None,
    ):
        """Train the noise model from signal - observation pairs.

        The model is optimized in place with Adam on the negative log-likelihood.

        Parameters
        ----------
        signal: numpy array
            Clean Signal Data.
        observation: numpy array
            Noisy Observation Data.
        learning_rate: float
            Learning rate. Default = 1e-1.
        batchSize: int
            Mini-batch size. Default = 250000.
        n_epochs: int
            Number of optimization steps. Default = 2000.
        name: string
            Model file name. Currently unused because saving to disk is disabled;
            kept for backward compatibility.
        lowerClip : int
            Lower percentile for clipping. Default is 0.
        upperClip : int
            Upper percentile for clipping. Default is 100.
        device : torch.device or str, optional
            Device used for training. Defaults to CUDA when available, else CPU
            (previously CUDA was hard-coded).

        Returns
        -------
        tuple
            ``(trained_weight, min_signal, max_signal, min_sigma)``: the weights
            (numpy copy) after the final optimization step, the signal bounds as
            numpy arrays and `min_sigma`.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Move the Parameter before building the optimizer. `likelihood` moves
        # the remaining auxiliary tensors through `to_device`.
        self.to(device)

        sig_obs_pairs = self.getSignalObservationPairs(
            signal, observation, lowerClip, upperClip
        )
        counter = 0
        optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
        for t in range(n_epochs):
            # Reshuffle and restart once the next batch would run past the end.
            if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
                counter = 0
                sig_obs_pairs = fastShuffle(sig_obs_pairs, 1)

            batch_vectors = sig_obs_pairs[
                counter * batchSize : (counter + 1) * batchSize, :
            ]
            observations = torch.from_numpy(
                batch_vectors[:, 1].astype(np.float32)
            ).to(device)
            signals = torch.from_numpy(batch_vectors[:, 0].astype(np.float32)).to(
                device
            )

            p = self.likelihood(observations, signals)
            loss = torch.mean(-torch.log(p))

            if t % 100 == 0:
                print(t, loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            counter += 1

        print("===================\n")
        # Snapshot after the final step. The previous implementation returned a
        # copy taken half-way through training, and crashed for n_epochs < 2.
        # TODO return istead of save ?
        trained_weight = self.weight.detach().cpu().numpy().copy()
        min_signal = self.min_signal.detach().cpu().numpy()
        max_signal = self.max_signal.detach().cpu().numpy()
        return trained_weight, min_signal, max_signal, self.min_sigma
|