careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/noise_models.py
@@ -1,26 +1,108 @@
-import
-
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional

 import numpy as np
 import torch
 import torch.nn as nn

-
+if TYPE_CHECKING:
+    from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig

+# TODO this module shouldn't be in lvae folder

-class DisentNoiseModel(nn.Module):

-
-
-
+def noise_model_factory(
+    model_config: Optional[MultiChannelNMConfig],
+) -> Optional[MultiChannelNoiseModel]:
+    """Noise model factory.
+
+    Parameters
+    ----------
+    model_config : Optional[MultiChannelNMConfig]
+        Noise model configuration, a `MultiChannelNMConfig` config that defines
+        noise models for the different output channels.
+
+    Returns
+    -------
+    Optional[MultiChannelNoiseModel]
+        A noise model instance.
+
+    Raises
+    ------
+    NotImplementedError
+        If the chosen noise model `model_type` is not implemented.
+        Currently only `GaussianMixtureNoiseModel` is implemented.
+    """
+    if model_config:
+        noise_models = []
+        for nm_config in model_config.noise_models:
+            if nm_config.path:
+                if nm_config.model_type == "GaussianMixtureNoiseModel":
+                    noise_models.append(GaussianMixtureNoiseModel(nm_config))
+                else:
+                    raise NotImplementedError(
+                        f"Model {nm_config.model_type} is not implemented"
+                    )
+
+            else:  # TODO this means signal/obs are provided. Controlled in pydantic model
+                # TODO train a new model. Config should always be provided?
+                if nm_config.model_type == "GaussianMixtureNoiseModel":
+                    trained_nm = train_gm_noise_model(nm_config)
+                    noise_models.append(trained_nm)
+                else:
+                    raise NotImplementedError(
+                        f"Model {nm_config.model_type} is not implemented"
+                    )
+        return MultiChannelNoiseModel(noise_models)
+    return None
+
+
+def train_gm_noise_model(
+    model_config: GaussianMixtureNMConfig,
+) -> GaussianMixtureNoiseModel:
+    """Train a Gaussian mixture noise model.
+
+    Parameters
+    ----------
+    model_config : GaussianMixtureNoiseModel
+        _description_
+
+    Returns
+    -------
+    _description_
+    """
+    # TODO where to put train params?
+    # TODO any training params ? Different channels ?
+    noise_model = GaussianMixtureNoiseModel(model_config)
+    # TODO revisit config unpacking
+    noise_model.train_noise_model(noise_model.signal, noise_model.observation)
+    return noise_model
+

-
+class MultiChannelNoiseModel(nn.Module):
+    def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):
+        """Constructor.
+
+        To handle noise models and the relative likelihood computation for multiple
+        output channels (e.g., muSplit, denoiseSplit).
+
+        This class:
+        - receives as input a variable number of noise models, one for each channel.
+        - computes the likelihood of observations given signals for each channel.
+        - returns the concatenation of these likelihoods.
+
+        Parameters
+        ----------
+        nmodels : list[GaussianMixtureNoiseModel]
+            List of noise models, one for each output channel.
         """
         super().__init__()
-        # self.nmodels = nmodels
         for i, nmodel in enumerate(nmodels):
             if nmodel is not None:
-                self.add_module(
+                self.add_module(
+                    f"nmodel_{i}", nmodel
+                )  # TODO: wouldn't be easier to use a list?

         self._nm_cnt = 0
         for nmodel in nmodels:
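Review note on the new factory: a minimal usage sketch, assuming `MultiChannelNMConfig` exposes its per-channel configs as `noise_models` and that each `GaussianMixtureNMConfig` carries the `model_type` and `path` fields used in the loop above (field names inferred from this hunk, not from package docs).

```python
# Hypothetical usage sketch of noise_model_factory (assumptions flagged above).
from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
from careamics.models.lvae.noise_models import noise_model_factory

# One pre-trained GMM noise model per output channel, each loaded from a .npz file.
ch1 = GaussianMixtureNMConfig(
    model_type="GaussianMixtureNoiseModel", path="noise_model_ch1.npz"
)
ch2 = GaussianMixtureNMConfig(
    model_type="GaussianMixtureNoiseModel", path="noise_model_ch2.npz"
)
multi_config = MultiChannelNMConfig(noise_models=[ch1, ch2])

# Returns a MultiChannelNoiseModel registering nmodel_0 and nmodel_1,
# or None when no config is passed.
noise_model = noise_model_factory(multi_config)
```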
@@ -30,181 +112,141 @@ class DisentNoiseModel(nn.Module):
         print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")

     def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
+        """Compute the likelihood of observations given signals for each channel.

+        Parameters
+        ----------
+        obs : torch.Tensor
+            Noisy observations, i.e., the target(s). Specifically, the input noisy
+            image for HDN, or the noisy unmixed images used for supervision
+            for denoiSplit. Shape: (B, C, [Z], Y, X), where C is the number of
+            unmixed channels.
+        signal : torch.Tensor
+            Underlying signals, i.e., the (clean) output of the model. Specifically, the
+            denoised image for HDN, or the unmixed images for denoiSplit.
+            Shape: (B, C, [Z], Y, X), where C is the number of unmixed channels.
+        """
+        # Case 1: obs and signal have a single channel (e.g., denoising)
         if obs.shape[1] == 1:
             assert signal.shape[1] == 1
-            assert self.n2model is None
             return self.nmodel_0.likelihood(obs, signal)

-
-
+        # Case 2: obs and signal have multiple channels (e.g., denoiSplit)
+        assert obs.shape[1] == self._nm_cnt, (
+            "The number of channels in `obs` must match the number of noise models."
+            f" Got instead: obs={obs.shape[1]}, nm={self._nm_cnt}"
+        )
         ll_list = []
         for ch_idx in range(obs.shape[1]):
             nmodel = getattr(self, f"nmodel_{ch_idx}")
             ll_list.append(
                 nmodel.likelihood(
                     obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1]
-                )
+                )  # slicing to keep the channel dimension
             )
-
         return torch.cat(ll_list, dim=1)


-
-    return os.path.join(*fpath.split("/")[-2:])
-
-
-def get_nm_config(noise_model_fpath: str):
-    config_fpath = os.path.join(os.path.dirname(noise_model_fpath), "config.json")
-    with open(config_fpath) as f:
-        noise_model_config = json.load(f)
-    return noise_model_config
-
-
+# TODO: is this needed?
 def fastShuffle(series, num):
+    """_summary_.
+
+    Parameters
+    ----------
+    series : _type_
+        _description_
+    num : _type_
+        _description_
+
+    Returns
+    -------
+    _type_
+        _description_
+    """
     length = series.shape[0]
-    for
+    for _ in range(num):
         series = series[np.random.permutation(length), :]
     return series


-def get_noise_model(
-    enable_noise_model: bool,
-    model_type: ModelType,
-    noise_model_type: str,
-    noise_model_ch1_fpath: str,
-    noise_model_ch2_fpath: str,
-    noise_model_learnable: bool = False,
-    denoise_channel: str = "input",
-):
-    if enable_noise_model:
-        nmodels = []
-        # HDN -> one single output -> one single noise model
-        if model_type == ModelType.Denoiser:
-            if noise_model_type == "hist":
-                raise NotImplementedError(
-                    '"hist" noise model is not supported for now.'
-                )
-            elif noise_model_type == "gmm":
-                if denoise_channel == "Ch1":
-                    nmodel_fpath = noise_model_ch1_fpath
-                    print(f"Noise model Ch1: {nmodel_fpath}")
-                    nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
-                    nmodel2 = None
-                    nmodels = [nmodel1, nmodel2]
-                elif denoise_channel == "Ch2":
-                    nmodel_fpath = noise_model_ch2_fpath
-                    print(f"Noise model Ch2: {nmodel_fpath}")
-                    nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
-                    nmodel2 = None
-                    nmodels = [nmodel1, nmodel2]
-                elif denoise_channel == "input":
-                    nmodel_fpath = noise_model_ch1_fpath
-                    print(f"Noise model input: {nmodel_fpath}")
-                    nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
-                    nmodel2 = None
-                    nmodels = [nmodel1, nmodel2]
-                else:
-                    raise ValueError(f"Invalid denoise_channel: {denoise_channel}")
-        # muSplit -> two outputs -> two noise models
-        elif noise_model_type == "gmm":
-            print(f"Noise model Ch1: {noise_model_ch1_fpath}")
-            print(f"Noise model Ch2: {noise_model_ch2_fpath}")
-
-            nmodel1 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch1_fpath))
-            nmodel2 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch2_fpath))
-
-            nmodels = [nmodel1, nmodel2]
-
-            # if 'noise_model_ch3_fpath' in config.model:
-            #     print(f'Noise model Ch3: {config.model.noise_model_ch3_fpath}')
-            #     nmodel3 = GaussianMixtureNoiseModel(params=np.load(config.model.noise_model_ch3_fpath))
-            #     nmodels = [nmodel1, nmodel2, nmodel3]
-            # else:
-            #     nmodels = [nmodel1, nmodel2]
-        else:
-            raise ValueError(f"Invalid noise_model_type: {noise_model_type}")
-
-        if noise_model_learnable:
-            for nmodel in nmodels:
-                if nmodel is not None:
-                    nmodel.make_learnable()
-
-        return DisentNoiseModel(*nmodels)
-    return None
-
-
 class GaussianMixtureNoiseModel(nn.Module):
-    """
-
-    If
-
+    """Define a noise model parameterized as a mixture of gaussians.
+
+    If `config.path` is not provided a new object is initialized from scratch.
+    Otherwise, a model is loaded from `config.path`.

     Parameters
     ----------
-
-
-
-
-
-
-
-
-
-
-
-
-    n_gaussian
-
-
-
-
-
-
-
-
-
-
-
+    config : GaussianMixtureNMConfig
+        A `pydantic` model that defines the configuration of the GMM noise model.
+
+    Attributes
+    ----------
+    min_signal : float
+        Minimum signal intensity expected in the image.
+    max_signal : float
+        Maximum signal intensity expected in the image.
+    path: Union[str, Path]
+        Path to the directory where the trained noise model (*.npz) is saved in the `train` method.
+    weight : torch.nn.Parameter
+        A [3*n_gaussian, n_coeff] sized array containing the values of the weights
+        describing the GMM noise model, with each row corresponding to one
+        parameter of each gaussian, namely [mean, standard deviation and weight].
+        Specifically, rows are organized as follows:
+        - first n_gaussian rows correspond to the means
+        - next n_gaussian rows correspond to the weights
+        - last n_gaussian rows correspond to the standard deviations
+        If `weight=None`, the weight array is initialized using the `min_signal`
+        and `max_signal` parameters.
+    n_gaussian: int
+        Number of gaussians in the mixture.
+    n_coeff: int
+        Number of coefficients to describe the functional relationship between gaussian
+        parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
+        relationship and so on.
+    device: device
+        GPU device.
+    min_sigma: float
+        All values of `standard deviation` below this are clamped to this value.
     """

-
+    # TODO training a NM relies on getting a clean data(N2V e.g,)
+    def __init__(self, config: GaussianMixtureNMConfig):
         super().__init__()
         self._learnable = False

-        if
-
-
-
-
-
+        if config.path is None:
+            # TODO this is (probably) to train a nm. We leave it for later refactoring
+            weight = config.weight
+            n_gaussian = config.n_gaussian
+            n_coeff = config.n_coeff
+            min_signal = config.min_signal
+            max_signal = config.max_signal
             # self.device = kwargs.get('device')
-
-            self.min_sigma =
+            # TODO min_sigma cant be None ?
+            self.min_sigma = config.min_sigma
             if weight is None:
                 weight = np.random.randn(n_gaussian * 3, n_coeff)
                 weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
-                weight = torch.from_numpy(
-
-                ).float()  # .to(self.device)
-                weight = nn.Parameter(weight, requires_grad=True)
+                weight = torch.from_numpy(weight.astype(np.float32)).float().cuda()
+                weight.requires_grad = True

             self.n_gaussian = weight.shape[0] // 3
             self.n_coeff = weight.shape[1]
             self.weight = weight
-            self.min_signal = torch.Tensor([min_signal])
-            self.max_signal = torch.Tensor([max_signal])
-            self.tol = torch.Tensor([1e-10])
+            self.min_signal = torch.Tensor([min_signal])
+            self.max_signal = torch.Tensor([max_signal])
+            self.tol = torch.Tensor([1e-10])
         else:
-            params =
+            params = np.load(config.path)
             # self.device = kwargs.get('device')

-            self.min_signal = torch.Tensor(params["min_signal"])
-            self.max_signal = torch.Tensor(params["max_signal"])
+            self.min_signal = torch.Tensor(params["min_signal"])
+            self.max_signal = torch.Tensor(params["max_signal"])

             self.weight = torch.nn.Parameter(
                 torch.Tensor(params["trained_weight"]), requires_grad=False
-            )
+            )
             self.min_sigma = params["min_sigma"].item()
             self.n_gaussian = self.weight.shape[0] // 3
             self.n_coeff = self.weight.shape[1]
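Review note: the tensor-shape contract of `MultiChannelNoiseModel.likelihood`, sketched under the assumption that a two-channel model has been built as in the factory example above (shapes follow the docstring and slicing in this hunk, not separate package docs).

```python
import torch

# Sketch of the assumed shape contract of MultiChannelNoiseModel.likelihood.
# With two single-channel GMMs registered as nmodel_0 / nmodel_1 (_nm_cnt == 2):
obs = torch.rand(4, 2, 64, 64)     # (B, C, Y, X) noisy targets; C must equal _nm_cnt
signal = torch.rand(4, 2, 64, 64)  # (B, C, Y, X) predicted clean signals

# Each channel ch is sliced as [:, ch : ch + 1] (keeping the channel dimension),
# evaluated by its own noise model, and the results are concatenated on dim=1:
# ll = noise_model.likelihood(obs, signal)   # -> torch.Size([4, 2, 64, 64])
```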
@@ -216,19 +258,17 @@ class GaussianMixtureNoiseModel(nn.Module):

     def make_learnable(self):
         print(f"[{self.__class__.__name__}] Making noise model learnable")
-
         self._learnable = True
         self.weight.requires_grad = True

-    #
-
     def to_device(self, cuda_tensor):
+        # TODO wtf is this ?
         # move everything to GPU
         if self.min_signal.device != cuda_tensor.device:
-            self.max_signal = self.max_signal.
-            self.min_signal = self.min_signal.
-            self.tol = self.tol.
-            self.weight = self.weight.
+            self.max_signal = self.max_signal.cuda()
+            self.min_signal = self.min_signal.cuda()
+            self.tol = self.tol.cuda()
+            # self.weight = self.weight.cuda()
         if self._learnable:
             self.weight.requires_grad = True

@@ -254,21 +294,24 @@ class GaussianMixtureNoiseModel(nn.Module):
         )
         return value

-    def normalDens(
-
+    def normalDens(
+        self, x: torch.Tensor, m_: torch.Tensor = 0.0, std_: torch.Tensor = None
+    ) -> torch.Tensor:
+        """Evaluates the normal probability density at `x` given the mean `m` and
+        standard deviation `std`.

         Parameters
         ----------
-        x: torch.
-            Observations
-        m_: torch.
-
-        std_: torch.
-
+        x: torch.Tensor
+            Observations (i.e., noisy image).
+        m_: torch.Tensor
+            Pixel-wise mean.
+        std_: torch.Tensor
+            Pixel-wise standard deviation.

         Returns
         -------
-        tmp: torch.
+        tmp: torch.Tensor
             Normal probability density of `x` given `m_` and `std_`
         """
         tmp = -((x - m_) ** 2)
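For reference, the `tmp` computation in `normalDens` (the last line of this hunk plus the first lines of the next one) evaluates the Gaussian density pixel-wise:

$$
\mathcal{N}(x \mid m, \sigma) \;=\; \frac{1}{\sqrt{2\pi\sigma^{2}}}\,
\exp\!\left(-\frac{(x - m)^{2}}{2\sigma^{2}}\right)
$$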
@@ -277,72 +320,73 @@ class GaussianMixtureNoiseModel(nn.Module):
         tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
         return tmp

-    def likelihood(
-
+    def likelihood(
+        self, observations: torch.Tensor, signals: torch.Tensor
+    ) -> torch.Tensor:
+        """Evaluate the likelihood of observations given the signals and the
+        corresponding gaussian parameters.

         Parameters
         ----------
         observations : torch.cuda.FloatTensor
-            Noisy observations
+            Noisy observations.
         signals : torch.cuda.FloatTensor
-            Underlying signals
+            Underlying signals.

         Returns
         -------
         value :p + self.tol
             Likelihood of observations given the signals and the GMM noise model
         """
-        self.to_device(signals)
+        self.to_device(signals)  # move al needed stuff to the same device as `signals``
         gaussianParameters = self.getGaussianParameters(signals)
         p = 0
         for gaussian in range(self.n_gaussian):
             p += (
                 self.normalDens(
-                    observations,
-                    gaussianParameters[gaussian],
-                    gaussianParameters[self.n_gaussian + gaussian],
+                    x=observations,
+                    m_=gaussianParameters[gaussian],
+                    std_=gaussianParameters[self.n_gaussian + gaussian],
                 )
                 * gaussianParameters[2 * self.n_gaussian + gaussian]
             )
         return p + self.tol

-    def getGaussianParameters(self, signals):
-        """Returns the noise model for given signals
+    def getGaussianParameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
+        """Returns the noise model for given signals.

         Parameters
         ----------
-        signals : torch.
+        signals : torch.Tensor
             Underlying signals

         Returns
         -------
-
-
+        gmmParams: list[torch.Tensor]
+            A list containing tensors representing `mu`, `sigma` and `alpha`
+            parameters for the `n_gaussian` gaussians in the mixture.

         """
-
+        gmmParams = []
         mu = []
         sigma = []
         alpha = []
         kernels = self.weight.shape[0] // 3
         for num in range(kernels):
+            # For each Gaussian in the mixture, evaluate mean, std and weight
             mu.append(self.polynomialRegressor(self.weight[num, :], signals))
-
+
             expval = torch.exp(self.weight[kernels + num, :])
-            #
+            # TODO: why taking the exp? it is not in PPN2V paper...
             sigmaTemp = self.polynomialRegressor(expval, signals)
             sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
             sigma.append(torch.sqrt(sigmaTemp))

-            # expval = torch.exp(
-            #     torch.clamp(
-            #         self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) + self.tol, MAX_ALPHA_W))
             expval = torch.exp(
                 self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
                 + self.tol
             )
-            #
-            alpha.append(expval)
+            alpha.append(expval)  # NOTE: these are the numerators of weights

         sum_alpha = 0
         for al in range(kernels):
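Taken together, `likelihood` and `getGaussianParameters` implement a signal-dependent gaussian mixture: each parameter is a polynomial in the signal $s$ with `n_coeff` coefficients read from one row of `weight` (rows $0..K{-}1$ give $\mu_k$; rows $K..2K{-}1$ give $\sigma_k$ via exp, polynomial regression, clamping at `min_sigma` and a square root; rows $2K..3K{-}1$ give the unnormalized $\alpha_k$). The returned likelihood is

$$
p(x \mid s) \;=\; \sum_{k=1}^{K} \alpha_k(s)\,
\mathcal{N}\big(x \mid \mu_k(s), \sigma_k(s)\big) \;+\; \mathrm{tol},
\qquad K = \texttt{n\_gaussian},
$$

with $\mathrm{tol} = 10^{-10}$ guarding against $\log 0$ during training, and the `sum_alpha` loop (continuing past this hunk) normalizing the $\alpha_k$. Note that this row ordering follows the code; the class docstring above lists the weight rows in a different order (means, weights, standard deviations).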
@@ -357,24 +401,24 @@ class GaussianMixtureNoiseModel(nn.Module):
         for ker in range(kernels):
             sum_means = alpha[ker] * mu[ker] + sum_means

-        mu_shifted = []
         # subtracting the alpha weighted average of the means from the means
         # ensures that the GMM has the inclination to have the mean=signals.
-        #
+        # TODO: I don't understand why we need to learn the mean?
         for ker in range(kernels):
             mu[ker] = mu[ker] - sum_means + signals

         for i in range(kernels):
-
+            gmmParams.append(mu[i])
         for j in range(kernels):
-
+            gmmParams.append(sigma[j])
         for k in range(kernels):
-
+            gmmParams.append(alpha[k])

-        return
+        return gmmParams

+    # TODO: this is to train the noise model
     def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
-        """Returns the Signal-Observation pixel intensities as a two-column array
+        """Returns the Signal-Observation pixel intensities as a two-column array.

         Parameters
         ----------
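The recentering loop above (`mu[ker] = mu[ker] - sum_means + signals`) makes the mixture unbiased around the signal: with normalized weights $\sum_k \alpha_k = 1$ and $M = \sum_j \alpha_j \mu_j$,

$$
\sum_{k} \alpha_k\,(\mu_k - M + s) \;=\; M - M + s \;=\; s,
$$

so the expected observation under the GMM equals the signal $s$.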
@@ -389,7 +433,7 @@ class GaussianMixtureNoiseModel(nn.Module):

         Returns
         -------
-
+        gmmParams: list of torch floats
             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
         """
         lb = np.percentile(signal, lowerClip)
@@ -407,3 +451,91 @@ class GaussianMixtureNoiseModel(nn.Module):
             (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
         ]
         return fastShuffle(sig_obs_pairs, 2)
+
+    # TODO: what's the use of this method?
+    def forward(self, x, y):
+        """Temporary dummy forward method."""
+        return x, y
+
+    # TODO taken from pn2v. Ashesh needs to clarify this
+    def train_noise_model(
+        self,
+        signal,
+        observation,
+        learning_rate=1e-1,
+        batchSize=250000,
+        n_epochs=2000,
+        name="GMMNoiseModel.npz",
+        lowerClip=0,
+        upperClip=100,
+    ):
+        """Training to learn the noise model from signal - observation pairs.
+
+        Parameters
+        ----------
+        signal: numpy array
+            Clean Signal Data
+        observation: numpy array
+            Noisy Observation Data
+        learning_rate: float
+            Learning rate. Default = 1e-1.
+        batchSize: int
+            Nini-batch size. Default = 250000.
+        n_epochs: int
+            Number of epochs. Default = 2000.
+        name: string
+
+            Model name. Default is `GMMNoiseModel`. This model after being trained is saved at the location `path`.
+
+        lowerClip : int
+            Lower percentile for clipping. Default is 0.
+        upperClip : int
+            Upper percentile for clipping. Default is 100.
+
+
+        """
+        sig_obs_pairs = self.getSignalObservationPairs(
+            signal, observation, lowerClip, upperClip
+        )
+        counter = 0
+        optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
+        for t in range(n_epochs):
+
+            jointLoss = 0
+            if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
+                counter = 0
+                sig_obs_pairs = fastShuffle(sig_obs_pairs, 1)
+
+            batch_vectors = sig_obs_pairs[
+                counter * batchSize : (counter + 1) * batchSize, :
+            ]
+            observations = batch_vectors[:, 1].astype(np.float32)
+            signals = batch_vectors[:, 0].astype(np.float32)
+            # TODO do we absolutely need to move to GPU?
+            observations = (
+                torch.from_numpy(observations.astype(np.float32)).float().cuda()
+            )
+            signals = torch.from_numpy(signals).float().cuda()
+            p = self.likelihood(observations, signals)
+            loss = torch.mean(-torch.log(p))
+            jointLoss = jointLoss + loss
+
+            if t % 100 == 0:
+                print(t, jointLoss.item())
+
+            if t % (int(n_epochs * 0.5)) == 0:
+                trained_weight = self.weight.cpu().detach().numpy()
+                min_signal = self.min_signal.cpu().detach().numpy()
+                max_signal = self.max_signal.cpu().detach().numpy()
+                # TODO do we need to save?
+                # np.savez(self.path+name, trained_weight=trained_weight, min_signal = min_signal, max_signal = max_signal, min_sigma = self.min_sigma)
+
+            optimizer.zero_grad()
+            jointLoss.backward()
+            optimizer.step()
+            counter += 1
+
+        print("===================\n")
+        # print("The trained parameters (" + name + ") is saved at location: "+ self.path)
+        # TODO return istead of save ?
+        return trained_weight, min_signal, max_signal, self.min_sigma