careamics 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of careamics for review.
- careamics/config/algorithms/care_algorithm_model.py +12 -24
- careamics/config/algorithms/n2n_algorithm_model.py +13 -25
- careamics/config/algorithms/n2v_algorithm_model.py +13 -19
- careamics/config/configuration_factories.py +84 -23
- careamics/config/data/data_model.py +47 -2
- careamics/config/support/supported_algorithms.py +5 -1
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/lightning/callbacks/progress_bar_callback.py +1 -1
- careamics/lightning/train_data_module.py +10 -19
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +123 -128
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -1
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/METADATA +5 -3
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/RECORD +26 -24
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
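The hunks below cover careamics/models/lvae/noise_models.py, by far the largest change in this release (+280/-217): the histogram helper is rewritten around np.histogram2d, the noise-model factory stops training models on the fly, and GaussianMixtureNoiseModel is refactored onto PyTorch buffers and parameters, with a reworked fit, a new sampling method, and a save/load round trip.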
@@ -3,6 +3,7 @@ from __future__ import annotations
 import os
 from typing import TYPE_CHECKING, Optional
 
+from numpy.typing import NDArray
 import numpy as np
 import torch
 import torch.nn as nn
@@ -13,63 +14,59 @@ if TYPE_CHECKING:
 # TODO this module shouldn't be in lvae folder
 
 
-def create_histogram(
+def create_histogram(
+    bins: int, min_val: float, max_val: float, observation: NDArray, signal: NDArray
+) -> NDArray:
     """
     Creates a 2D histogram from 'observation' and 'signal'.
 
     Parameters
     ----------
-    bins: int
-
-    min_val: float
-
-    max_val: float
-
-    observation:
-
-
-
-    signal:
-
+    bins : int
+        Number of bins in x and y.
+    min_val : float
+        Lower bound of the lowest bin in x and y.
+    max_val : float
+        Upper bound of the highest bin in x and y.
+    observation : np.ndarray
+        3D numpy array (stack of 2D images).
+        Observation.shape[0] must be divisible by signal.shape[0].
+        Assumes that n subsequent images in observation belong to one image in 'signal'.
+    signal : np.ndarray
+        3D numpy array (stack of 2D images).
 
     Returns
     -------
-    histogram:
+    histogram : np.ndarray
         A 3D array:
-
-
-
-
-        The values for x can be obtained by transposing 'histogram[1,...]' and 'histogram[2,...]'.
+        - histogram[0]: Normalized 2D counts.
+        - histogram[1]: Lower boundaries of bins along y.
+        - histogram[2]: Upper boundaries of bins along y.
+        The values for x can be obtained by transposing 'histogram[1]' and 'histogram[2]'.
     """
-    # TODO refactor this function
-    img_factor = int(observation.shape[0] / signal.shape[0])
     histogram = np.zeros((3, bins, bins))
 (… remaining removed lines of the old implementation were not captured in this diff view …)
-        1:
-    ]  # The upper boundaries of each bin in y are stored in dimension 2
-    # The accordent numbers for x are just transopsed.
+
+    value_range = [min_val, max_val]
+
+    # Compute mapping factor between observation and signal samples
+    obs_to_signal_shape_factor = int(observation.shape[0] / signal.shape[0])
+
+    # Flatten arrays and align signal values
+    signal_indices = np.arange(observation.shape[0]) // obs_to_signal_shape_factor
+    signal_values = signal[signal_indices].ravel()
+    observation_values = observation.ravel()
+
+    count_histogram, signal_edges, _ = np.histogram2d(
+        signal_values, observation_values, bins=bins, range=[value_range, value_range]
+    )
+
+    # Normalize rows to obtain probabilities
+    row_sums = count_histogram.sum(axis=1, keepdims=True)
+    count_histogram /= np.clip(row_sums, a_min=1e-20, a_max=None)
+
+    histogram[0] = count_histogram
+    histogram[1] = signal_edges[:-1][..., np.newaxis]
+    histogram[2] = signal_edges[1:][..., np.newaxis]
 
     return histogram
 
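The rewritten create_histogram delegates binning to np.histogram2d and row-normalizes the counts. A minimal usage sketch, assuming the new signature from this hunk (array shapes are illustrative):

```python
import numpy as np
from careamics.models.lvae.noise_models import create_histogram

signal = np.random.rand(10, 64, 64)  # 10 clean frames
# two noisy acquisitions per clean frame, so shape[0] is divisible by signal.shape[0]
observation = np.repeat(signal, 2, axis=0) + 0.05 * np.random.randn(20, 64, 64)

hist = create_histogram(
    bins=128, min_val=0.0, max_val=1.0, observation=observation, signal=signal
)
assert hist.shape == (3, 128, 128)
# hist[0]: per-signal-bin probabilities of observing each intensity
# hist[1] / hist[2]: lower / upper bin edges broadcast along the signal axis
```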
@@ -111,8 +108,11 @@ def noise_model_factory(
         # TODO train a new model. Config should always be provided?
         if nm.model_type == "GaussianMixtureNoiseModel":
             # TODO one model for each channel all make this choise inside the model?
-            trained_nm = train_gm_noise_model(nm)
-            noise_models.append(trained_nm)
+            # trained_nm = train_gm_noise_model(nm)
+            # noise_models.append(trained_nm)
+            raise NotImplementedError(
+                "GaussianMixtureNoiseModel model training is not implemented."
+            )
         else:
             raise NotImplementedError(
                 f"Model {nm.model_type} is not implemented"
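Behavioral note: noise_model_factory no longer trains a GaussianMixtureNoiseModel on the fly; a config without pre-trained weights now raises NotImplementedError. Training is instead done directly through the reworked GaussianMixtureNoiseModel.fit further down in this diff.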
@@ -163,6 +163,8 @@ class MultiChannelNoiseModel(nn.Module):
             List of noise models, one for each output channel.
         """
         super().__init__()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         for i, nmodel in enumerate(nmodels):  # TODO refactor this !!!
             if nmodel is not None:
                 self.add_module(
@@ -176,6 +178,13 @@ class MultiChannelNoiseModel(nn.Module):
 
         print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")
 
+    def to_device(self, device: torch.device):
+        self.device = device
+        self.to(device)
+        for ch_idx in range(self._nm_cnt):
+            nmodel = getattr(self, f"nmodel_{ch_idx}")
+            nmodel.to_device(device)
+
     def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
         """Compute the likelihood of observations given signals for each channel.
 
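Both noise-model classes now carry an explicit device attribute with a to_device helper; on the multichannel wrapper it recurses into every per-channel child module. A two-line sketch, where gmm_ch0 and gmm_ch1 stand for per-channel GaussianMixtureNoiseModel instances and the constructor argument follows the nmodels list seen above:

```python
import torch
from careamics.models.lvae.noise_models import MultiChannelNoiseModel

multi_nm = MultiChannelNoiseModel([gmm_ch0, gmm_ch1])  # hypothetical per-channel models
multi_nm.to_device(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
```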
@@ -212,28 +221,6 @@ class MultiChannelNoiseModel(nn.Module):
         return torch.cat(ll_list, dim=1)
 
 
-# TODO: is this needed?
-def fastShuffle(series, num):
-    """_summary_.
-
-    Parameters
-    ----------
-    series : _type_
-        _description_
-    num : _type_
-        _description_
-
-    Returns
-    -------
-    _type_
-        _description_
-    """
-    length = series.shape[0]
-    for _ in range(num):
-        series = series[np.random.permutation(length), :]
-    return series
-
-
 class GaussianMixtureNoiseModel(nn.Module):
     """Define a noise model parameterized as a mixture of gaussians.
 
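The module-level numpy fastShuffle removed here is not dropped outright: it returns later in this diff as the static method GaussianMixtureNoiseModel._fast_shuffle, reimplemented with torch.randperm so the signal-observation pairs stay torch tensors end to end.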
@@ -276,166 +263,176 @@ class GaussianMixtureNoiseModel(nn.Module):
     """
 
     # TODO training a NM relies on getting a clean data(N2V e.g,)
-    def __init__(self, config: GaussianMixtureNMConfig):
+    def __init__(self, config: GaussianMixtureNMConfig) -> None:
         super().__init__()
+        self.device = torch.device("cpu")
 
-
-        if config.path is None:
-            self.mode = "train"
-            # TODO this is (probably) to train a nm. We leave it for later refactoring
-            weight = config.weight
-            n_gaussian = config.n_gaussian
-            n_coeff = config.n_coeff
-            min_signal = torch.Tensor([config.min_signal])
-            max_signal = torch.Tensor([config.max_signal])
-            # TODO min_sigma cant be None ?
-            self.min_sigma = config.min_sigma
-            if weight is None:
-                weight = torch.randn(n_gaussian * 3, n_coeff)
-                weight[n_gaussian : 2 * n_gaussian, 1] = (
-                    torch.log(max_signal - min_signal).float().to(self.device)
-                )
-                weight.requires_grad = True
-
-            self.n_gaussian = weight.shape[0] // 3
-            self.n_coeff = weight.shape[1]
-            self.weight = weight
-            self.min_signal = torch.Tensor([min_signal]).to(self.device)
-            self.max_signal = torch.Tensor([max_signal]).to(self.device)
-            self.tol = torch.tensor([1e-10]).to(self.device)
-            # TODO refactor to train on CPU!
-        else:
+        if config.path is not None:
             params = np.load(config.path)
-
+        else:
+            params = config.model_dump(exclude_none=True)
+
+        min_sigma = torch.tensor(params["min_sigma"])
+        min_signal = torch.tensor(params["min_signal"])
+        max_signal = torch.tensor(params["max_signal"])
+        self.register_buffer("min_signal", min_signal)
+        self.register_buffer("max_signal", max_signal)
+        self.register_buffer("min_sigma", min_sigma)
+        self.register_buffer("tolerance", torch.tensor([1e-10]))
+
+        if "trained_weight" in params:
+            weight = torch.tensor(params["trained_weight"])
+        elif "weight" in params and params["weight"] is not None:
+            weight = torch.tensor(params["weight"])
+        else:
+            weight = self._initialize_weights(
+                params["n_gaussian"], params["n_coeff"], max_signal, min_signal
+            )
 
-
-
+        self.n_gaussian = weight.shape[0] // 3
+        self.n_coeff = weight.shape[1]
 
-
-
-        self.n_gaussian = self.weight.shape[0] // 3  # TODO why // 3 ?
-        self.n_coeff = self.weight.shape[1]
-        self.tol = torch.Tensor([1e-10])
-        self.min_signal = torch.Tensor([self.min_signal])
-        self.max_signal = torch.Tensor([self.max_signal])
+        self.register_parameter("weight", nn.Parameter(weight))
+        self._set_model_mode(mode="prediction")
 
         print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")
 
-    def
-
+    def _initialize_weights(
+        self,
+        n_gaussian: int,
+        n_coeff: int,
+        max_signal: torch.Tensor,
+        min_signal: torch.Tensor,
+    ) -> torch.Tensor:
+        """Create random weight initialization."""
+        weight = torch.randn(n_gaussian * 3, n_coeff)
+        weight[n_gaussian : 2 * n_gaussian, 1] = torch.log(
+            max_signal - min_signal
+        ).float()
+        return weight
+
+    def to_device(self, device: torch.device):
+        self.device = device
+        self.to(device)
+
+    def _set_model_mode(self, mode: str) -> None:
+        """Move parameters to the device and set weights' requires_grad depending on the mode"""
+        if mode == "train":
+            self.weight.requires_grad = True
+        else:
+            self.weight.requires_grad = False
+
+    def polynomial_regressor(
+        self, weight_params: torch.Tensor, signals: torch.Tensor
+    ) -> torch.Tensor:
+        """Combines `weight_params` and signal `signals` to regress for the gaussian parameter values.
 
         Parameters
         ----------
-
+        weight_params : Tensor
             Corresponds to specific rows of the `self.weight`
 
-        signals :
+        signals : Tensor
             Signals
 
         Returns
         -------
-        value :
+        value : Tensor
             Corresponds to either of mean, standard deviation or weight, evaluated at `signals`
         """
-        value =
-        for i in range(
-            value +=
+        value = torch.zeros_like(signals)
+        for i in range(weight_params.shape[0]):
+            value += weight_params[i] * (
                 ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
             )
         return value
 
-    def
-
+    def normal_density(
+        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Evaluates the normal probability density at `x` given the mean `mean` and standard deviation `std`.
 
         Parameters
         ----------
-        x:
+        x: Tensor
             Observations
-
+        mean: Tensor
             Mean
-
+        std: Tensor
             Standard-deviation
 
         Returns
         -------
-        tmp:
-            Normal probability density of `x` given `
-
+        tmp: Tensor
+            Normal probability density of `x` given `mean` and `std`
         """
-        tmp = -((x -
-        tmp = tmp / (2.0 *
+        tmp = -((x - mean) ** 2)
+        tmp = tmp / (2.0 * std * std)
         tmp = torch.exp(tmp)
-        tmp = tmp / torch.sqrt((2.0 * np.pi) *
-        # print(tmp.min().item(), tmp.mean().item(), tmp.max().item(), tmp.shape)
+        tmp = tmp / torch.sqrt((2.0 * np.pi) * std * std)
         return tmp
 
-    def likelihood(
-
+    def likelihood(
+        self, observations: torch.Tensor, signals: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.
 
         Parameters
         ----------
-        observations :
+        observations : Tensor
             Noisy observations
-        signals :
+        signals : Tensor
             Underlying signals
 
         Returns
         -------
-        value
+        value: torch.Tensor:
             Likelihood of observations given the signals and the GMM noise model
-
         """
-
-        signals = signals.cpu()
-        observations = observations.cpu()
-        self.weight = self.weight.to(signals.device)
-        self.min_signal = self.min_signal.to(signals.device)
-        self.max_signal = self.max_signal.to(signals.device)
-        self.tol = self.tol.to(signals.device)
-
-        gaussianParameters = self.getGaussianParameters(signals)
+        gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
         p = 0
         for gaussian in range(self.n_gaussian):
             p += (
-                self.
+                self.normal_density(
                     observations,
-
-
+                    gaussian_parameters[gaussian],
+                    gaussian_parameters[self.n_gaussian + gaussian],
                 )
-                *
+                * gaussian_parameters[2 * self.n_gaussian + gaussian]
             )
-        return p + self.
+        return p + self.tolerance
 
-    def
-        """
+    def get_gaussian_parameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
+        """
+        Returns the noise model for given signals
 
         Parameters
         ----------
-        signals :
+        signals : Tensor
             Underlying signals
 
         Returns
         -------
-
+        noise_model: list of Tensor
             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
         """
-
+        noise_model = []
         mu = []
         sigma = []
         alpha = []
         kernels = self.weight.shape[0] // 3
         for num in range(kernels):
-            mu.append(self.
-            # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W))
+            mu.append(self.polynomial_regressor(self.weight[num, :], signals))
             expval = torch.exp(self.weight[kernels + num, :])
-
-
-
-            sigma.append(torch.sqrt(sigmaTemp))
+            sigma_temp = self.polynomial_regressor(expval, signals)
+            sigma_temp = torch.clamp(sigma_temp, min=self.min_sigma)
+            sigma.append(torch.sqrt(sigma_temp))
 
             expval = torch.exp(
-                self.
-                + self.
+                self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
+                + self.tolerance
             )
             alpha.append(expval)
 
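This refactor replaces hand-moved tensors (the old likelihood called .cpu() and re-moved every attribute) with registered buffers and an nn.Parameter, and it pins down what the methods compute: per pixel, polynomial_regressor evaluates each component's mean, variance, and mixing weight as a polynomial in the normalized signal, and likelihood sums the weighted densities:

$$
p(x \mid s) \;=\; \sum_{k=1}^{K} \alpha_k(s)\,\mathcal{N}\!\big(x;\ \mu_k(s),\ \sigma_k(s)^2\big) \;+\; \mathrm{tol},
\qquad K = \texttt{n\_gaussian},\quad \mathrm{tol} = 10^{-10},
$$

where tol is the registered tolerance buffer that keeps the log finite during training and sigma_k(s) is the square root of the regressed variance clamped at min_sigma. A minimal sketch, not careamics code, of what buffer registration buys over plain tensor attributes:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attribute: invisible to state_dict() and left behind by .to().
        self.plain = torch.tensor([1e-10])
        # Registered buffer: saved, loaded, and moved together with the module.
        self.register_buffer("tolerance", torch.tensor([1e-10]))

m = Demo()
print("tolerance" in m.state_dict())  # True
print("plain" in m.state_dict())      # False
if torch.cuda.is_available():
    m.to("cuda")
    print(m.plain.device, m.tolerance.device)  # cpu cuda:0
```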
@@ -459,15 +456,30 @@ class GaussianMixtureNoiseModel(nn.Module):
             mu[ker] = mu[ker] - sum_means + signals
 
         for i in range(kernels):
-
+            noise_model.append(mu[i])
         for j in range(kernels):
-
+            noise_model.append(sigma[j])
         for k in range(kernels):
-
+            noise_model.append(alpha[k])
+
+        return noise_model
 
-
+    @staticmethod
+    def _fast_shuffle(series: torch.Tensor, num: int) -> torch.Tensor:
+        """Shuffle the inputs randomly num times"""
+        length = series.shape[0]
+        for _ in range(num):
+            idx = torch.randperm(length)
+            series = series[idx, :]
+        return series
 
-    def
+    def get_signal_observation_pairs(
+        self,
+        signal: NDArray,
+        observation: NDArray,
+        lower_clip: float,
+        upper_clip: float,
+    ) -> torch.Tensor:
         """Returns the Signal-Observation pixel intensities as a two-column array
 
         Parameters
@@ -476,19 +488,18 @@ class GaussianMixtureNoiseModel(nn.Module):
             Clean Signal Data
         observation: numpy array
             Noisy observation Data
-
+        lower_clip: float
             Lower percentile bound for clipping.
-
+        upper_clip: float
             Upper percentile bound for clipping.
 
         Returns
         -------
-
+        noise_model: list of torch floats
             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
-
         """
-        lb = np.percentile(signal,
-        ub = np.percentile(signal,
+        lb = np.percentile(signal, lower_clip)
+        ub = np.percentile(signal, upper_clip)
         stepsize = observation[0].size
         n_observations = observation.shape[0]
         n_signals = signal.shape[0]
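Note that the Returns section added to get_signal_observation_pairs still describes the output of get_gaussian_parameters (a list of mu, sigma and alpha), while the method actually returns the shuffled signal-observation tensor built in the next hunk; this reads like a copy-paste slip carried into the release.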
@@ -501,19 +512,20 @@ class GaussianMixtureNoiseModel(nn.Module):
         sig_obs_pairs = sig_obs_pairs[
             (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
         ]
-
+        sig_obs_pairs = sig_obs_pairs.astype(np.float32)
+        sig_obs_pairs = torch.from_numpy(sig_obs_pairs)
+        return self._fast_shuffle(sig_obs_pairs, 2)
 
     def fit(
         self,
-        signal,
-        observation,
-        learning_rate=1e-1,
-
-        n_epochs=2000,
-
-
-
-    ):
+        signal: NDArray,
+        observation: NDArray,
+        learning_rate: float = 1e-1,
+        batch_size: int = 250000,
+        n_epochs: int = 2000,
+        lower_clip: float = 0.0,
+        upper_clip: float = 100.0,
+    ) -> list[float]:
         """Training to learn the noise model from signal - observation pairs.
 
         Parameters
@@ -524,49 +536,42 @@ class GaussianMixtureNoiseModel(nn.Module):
             Noisy Observation Data
         learning_rate: float
             Learning rate. Default = 1e-1.
-
+        batch_size: int
             Nini-batch size. Default = 250000.
         n_epochs: int
             Number of epochs. Default = 2000.
-
-
-            Model name. Default is `GMMNoiseModel`. This model after being trained is saved at the location `path`.
-
-        lowerClip : int
+        lower_clip : int
             Lower percentile for clipping. Default is 0.
-
+        upper_clip : int
             Upper percentile for clipping. Default is 100.
-
-
         """
-
-
-        )
-        counter = 0
+        self._set_model_mode(mode="train")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.to_device(device)
         optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
-        loss_arr = []
 
+        sig_obs_pairs = self.get_signal_observation_pairs(
+            signal, observation, lower_clip, upper_clip
+        )
+
+        train_losses = []
+        counter = 0
         for t in range(n_epochs):
-            if (counter + 1) *
+            if (counter + 1) * batch_size >= sig_obs_pairs.shape[0]:
                 counter = 0
-                sig_obs_pairs =
+                sig_obs_pairs = self._fast_shuffle(sig_obs_pairs, 1)
 
             batch_vectors = sig_obs_pairs[
-                counter *
+                counter * batch_size : (counter + 1) * batch_size, :
             ]
-            observations = batch_vectors[:, 1].
-            signals = batch_vectors[:, 0].
-            observations = (
-                torch.from_numpy(observations.astype(np.float32))
-                .float()
-                .to(self.device)
-            )
-            signals = torch.from_numpy(signals).float().to(self.device)
+            observations = batch_vectors[:, 1].to(self.device)
+            signals = batch_vectors[:, 0].to(self.device)
 
             p = self.likelihood(observations, signals)
 
-
-
+            joint_loss = torch.mean(-torch.log(p))
+            train_losses.append(joint_loss.item())
+
             if self.weight.isnan().any() or self.weight.isinf().any():
                 print(
                     "NaN or Inf detected in the weights. Aborting training at epoch: ",
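Taken together with the previous two hunks, fit now sets train mode, moves to GPU when available, builds the shuffled signal-observation tensor once, and minimizes the mean negative log-likelihood torch.mean(-torch.log(p)), returning the per-epoch losses. A usage sketch with illustrative shapes; the GaussianMixtureNMConfig import path and exact field names are assumptions based on the keys read in the new __init__:

```python
import numpy as np
from careamics.config import GaussianMixtureNMConfig  # import path is an assumption
from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel

# Field names mirror the params keys used by the new __init__; the schema may differ.
config = GaussianMixtureNMConfig(
    model_type="GaussianMixtureNoiseModel",
    min_signal=0.0,
    max_signal=1.0,
    min_sigma=0.01,
    n_gaussian=3,
    n_coeff=2,
)
gmm = GaussianMixtureNoiseModel(config)

signal = np.random.rand(10, 64, 64)                        # clean stack
observation = signal + 0.05 * np.random.randn(10, 64, 64)  # paired noisy stack

losses = gmm.fit(signal, observation, batch_size=50_000, n_epochs=200)
print(f"final mean NLL: {losses[-1]:.4f}")
```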
@@ -575,19 +580,77 @@ class GaussianMixtureNoiseModel(nn.Module):
                 break
 
             if t % 100 == 0:
-
+                last_losses = train_losses[-100:]
+                print(t, np.mean(last_losses))
 
             optimizer.zero_grad()
-
+            joint_loss.backward()
             optimizer.step()
             counter += 1
 
-        self.
-        self.
-        self.max_signal = self.max_signal.cpu().detach().numpy()
+        self._set_model_mode(mode="prediction")
+        self.to_device(torch.device("cpu"))
         print("===================\n")
+        return train_losses
+
+    def sample_observation_from_signal(self, signal: NDArray) -> NDArray:
+        """
+        Sample an instance of observation based on an input signal using a
+        learned Gaussian Mixture Model. For each pixel in the input signal,
+        samples a corresponding noisy pixel.
+
+        Parameters
+        ----------
+        signal: numpy array
+            Clean 2D signal data.
+
+        Returns
+        -------
+        observation: numpy array
+            An instance of noisy observation data based on the input signal.
+        """
+        assert len(signal.shape) == 2, "Only 2D inputs are supported."
+
+        signal_tensor = torch.from_numpy(signal).to(torch.float32)
+        height, width = signal_tensor.shape
+
+        with torch.no_grad():
+            # Get gaussian parameters for each pixel
+            gaussian_params = self.get_gaussian_parameters(signal_tensor)
+            means = np.array(gaussian_params[: self.n_gaussian])
+            stds = np.array(gaussian_params[self.n_gaussian : self.n_gaussian * 2])
+            alphas = np.array(gaussian_params[self.n_gaussian * 2 :])
+
+            if self.n_gaussian == 1:
+                # Single gaussian case
+                observation = np.random.normal(
+                    loc=means[0], scale=stds[0], size=(height, width)
+                )
+            else:
+                # Multiple gaussians: sample component for each pixel
+                uniform = np.random.rand(1, height, width)
+                # Compute cumulative probabilities for component selection
+                cumulative_alphas = np.cumsum(
+                    alphas, axis=0
+                )  # Shape: (n_gaussian, height, width)
+                selected_component = np.argmax(
+                    uniform < cumulative_alphas, axis=0, keepdims=True
+                )
+
+                # For every pixel, choose the corresponding gaussian
+                # and get the learned mu and sigma
+                selected_mus = np.take_along_axis(means, selected_component, axis=0)
+                selected_stds = np.take_along_axis(stds, selected_component, axis=0)
+                selected_mus = selected_mus.squeeze(0)
+                selected_stds = selected_stds.squeeze(0)
+
+                # Sample from the normal distribution with learned mu and sigma
+                observation = np.random.normal(
+                    selected_mus, selected_stds, size=(height, width)
+                )
+        return observation
 
-    def save(self, path: str, name: str):
+    def save(self, path: str, name: str) -> None:
         """Save the trained parameters on the noise model.
 
         Parameters
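The new sample_observation_from_signal runs the noise model in the generative direction: for every pixel it draws a mixture component from the cumulative alphas, then samples from that component's gaussian. A short sketch, reusing the gmm model from the fit example above:

```python
import numpy as np

clean = np.random.rand(64, 64).astype(np.float32)  # 2D only, per the assert
noisy = gmm.sample_observation_from_signal(clean)
assert noisy.shape == clean.shape
```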
@@ -600,9 +663,9 @@ class GaussianMixtureNoiseModel(nn.Module):
         os.makedirs(path, exist_ok=True)
         np.savez(
             os.path.join(path, name),
-            trained_weight=self.
-            min_signal=self.min_signal,
-            max_signal=self.max_signal,
+            trained_weight=self.weight.numpy(),
+            min_signal=self.min_signal.numpy(),
+            max_signal=self.max_signal.numpy(),
             min_sigma=self.min_sigma,
         )
         print("The trained parameters (" + name + ") is saved at location: " + path)