careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/careamist.py +235 -25
- careamics/cli/conf.py +19 -30
- careamics/cli/main.py +111 -10
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +2 -0
- careamics/config/architectures/lvae_model.py +104 -21
- careamics/config/configuration_factory.py +49 -45
- careamics/config/configuration_model.py +2 -2
- careamics/config/likelihood_model.py +7 -6
- careamics/config/loss_model.py +56 -0
- careamics/config/nm_model.py +24 -24
- careamics/config/vae_algorithm_model.py +14 -13
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/lightning/lightning_module.py +58 -27
- careamics/lightning/train_data_module.py +15 -1
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/bioimage/_readme_factory.py +25 -33
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +39 -17
- careamics/model_io/bmz_io.py +36 -25
- careamics/models/layers.py +6 -4
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -272
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
The diff below covers careamics/models/lvae/noise_models.py:

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from typing import TYPE_CHECKING, Optional

 import numpy as np
@@ -12,6 +13,67 @@ if TYPE_CHECKING:
 # TODO this module shouldn't be in lvae folder


+def create_histogram(bins, min_val, max_val, observation, signal):
+    """
+    Creates a 2D histogram from 'observation' and 'signal'.
+
+    Parameters
+    ----------
+    bins: int
+        The number of bins in x and y. The total number of 2D bins is 'bins'**2.
+    min_val: float
+        The lower bound of the lowest bin in x and y.
+    max_val: float
+        The upper bound of the highest bin in x and y.
+    observation: numpy array
+        A 3D numpy array that is interpreted as a stack of 2D images.
+        The number of images has to be divisible by the number of images in 'signal'.
+        It is assumed that n subsequent images in observation belong to one image in 'signal'.
+    signal: numpy array
+        A 3D numpy array that is interpreted as a stack of 2D images.
+
+    Returns
+    -------
+    histogram: numpy array
+        A 3D array:
+        'histogram[0,...]' holds the normalized 2D counts.
+        Each row sums to 1, describing p(x_i|s_i).
+        'histogram[1,...]' holds the lower boundaries of each bin in y.
+        'histogram[2,...]' holds the upper boundaries of each bin in y.
+        The values for x can be obtained by transposing 'histogram[1,...]' and 'histogram[2,...]'.
+    """
+    # TODO refactor this function
+    img_factor = int(observation.shape[0] / signal.shape[0])
+    histogram = np.zeros((3, bins, bins))
+    ra = [min_val, max_val]
+
+    for i in range(observation.shape[0]):
+        observation_ = observation[i].copy().ravel()
+
+        signal_ = (signal[i // img_factor].copy()).ravel()
+        a = np.histogram2d(signal_, observation_, bins=bins, range=[ra, ra])
+        histogram[0] = histogram[0] + a[0] + 1e-30  # This is for numerical stability
+
+    for i in range(bins):
+        if (
+            np.sum(histogram[0, i, :]) > 1e-20
+        ):  # We exclude empty rows from normalization
+            histogram[0, i, :] /= np.sum(
+                histogram[0, i, :]
+            )  # we normalize each non-empty row
+
+    for i in range(bins):
+        histogram[1, :, i] = a[1][
+            :-1
+        ]  # The lower boundaries of each bin in y are stored in dimension 1
+        histogram[2, :, i] = a[1][
+            1:
+        ]  # The upper boundaries of each bin in y are stored in dimension 2
+        # The corresponding numbers for x are just transposed.
+
+    return histogram
+
+
 def noise_model_factory(
     model_config: Optional[MultiChannelNMConfig],
 ) -> Optional[MultiChannelNoiseModel]:
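The new `create_histogram` helper could be exercised along these lines; a minimal sketch with illustrative shapes (here each clean frame is assumed to have eight noisy acquisitions):

    import numpy as np

    signal = np.random.rand(2, 64, 64) * 100    # 2 clean frames
    observation = np.repeat(signal, 8, axis=0)  # 16 noisy frames, 8 per clean frame
    observation = observation + np.random.randn(*observation.shape) * 5

    hist = create_histogram(
        bins=128,
        min_val=min(signal.min(), observation.min()),
        max_val=max(signal.max(), observation.max()),
        observation=observation,
        signal=signal,
    )
    # hist[0]: row-normalized counts, i.e. p(observation | signal) per signal bin
    # hist[1], hist[2]: lower/upper bin boundaries in y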
@@ -36,23 +98,24 @@ def noise_model_factory(
     """
     if model_config:
         noise_models = []
-        for
-            if
-                if
-                    noise_models.append(GaussianMixtureNoiseModel(
+        for nm in model_config.noise_models:
+            if nm.path:
+                if nm.model_type == "GaussianMixtureNoiseModel":
+                    noise_models.append(GaussianMixtureNoiseModel(nm))
                 else:
                     raise NotImplementedError(
-                        f"Model {
+                        f"Model {nm.model_type} is not implemented"
                     )

             else:  # TODO this means signal/obs are provided. Controlled in pydantic model
                 # TODO train a new model. Config should always be provided?
-                if
-
+                if nm.model_type == "GaussianMixtureNoiseModel":
+                    # TODO one model for each channel, or make this choice inside the model?
+                    trained_nm = train_gm_noise_model(nm)
                     noise_models.append(trained_nm)
                 else:
                     raise NotImplementedError(
-                        f"Model {
+                        f"Model {nm.model_type} is not implemented"
                    )
         return MultiChannelNoiseModel(noise_models)
     return None
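For orientation, a minimal sketch of how the factory could be driven, assuming `GaussianMixtureNMConfig` and `MultiChannelNMConfig` are the pydantic models from careamics/config/nm_model.py (the .npz path is hypothetical and would come from a previously saved model):

    from careamics.config.nm_model import (
        GaussianMixtureNMConfig,
        MultiChannelNMConfig,
    )

    gmm_config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        path="noise_models/channel0.npz",  # hypothetical pre-trained file
    )
    config = MultiChannelNMConfig(noise_models=[gmm_config])
    noise_model = noise_model_factory(config)  # MultiChannelNoiseModel or None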
@@ -60,6 +123,8 @@ def noise_model_factory(

 def train_gm_noise_model(
     model_config: GaussianMixtureNMConfig,
+    signal: np.ndarray,
+    observation: np.ndarray,
 ) -> GaussianMixtureNoiseModel:
     """Train a Gaussian mixture noise model.

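With the widened signature, training could look like the sketch below; the config field names follow the factory code above, the data is illustrative, and the remaining config fields are assumed to have defaults:

    import numpy as np

    # Illustrative paired data; in practice these would be co-registered
    # clean/noisy stacks.
    signal = np.random.rand(4, 128, 128) * 100
    observation = signal + np.random.randn(4, 128, 128) * 5

    config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        min_signal=float(signal.min()),
        max_signal=float(signal.max()),
    )
    noise_model = train_gm_noise_model(config, signal, observation)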
@@ -76,7 +141,7 @@ def train_gm_noise_model(
     # TODO any training params ? Different channels ?
     noise_model = GaussianMixtureNoiseModel(model_config)
     # TODO revisit config unpacking
-    noise_model.
+    noise_model.fit(signal, observation)
     return noise_model


@@ -98,7 +163,7 @@ class MultiChannelNoiseModel(nn.Module):
             List of noise models, one for each output channel.
         """
         super().__init__()
-        for i, nmodel in enumerate(nmodels):
+        for i, nmodel in enumerate(nmodels):  # TODO refactor this !!!
             if nmodel is not None:
                 self.add_module(
                     f"nmodel_{i}", nmodel
@@ -213,65 +278,49 @@ class GaussianMixtureNoiseModel(nn.Module):
     # TODO training a NM relies on getting a clean data(N2V e.g,)
     def __init__(self, config: GaussianMixtureNMConfig):
         super().__init__()
-        self._learnable = False

+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if config.path is None:
+            self.mode = "train"
             # TODO this is (probably) to train a nm. We leave it for later refactoring
             weight = config.weight
             n_gaussian = config.n_gaussian
             n_coeff = config.n_coeff
-            min_signal = config.min_signal
-            max_signal = config.max_signal
-            # self.device = kwargs.get('device')
+            min_signal = torch.Tensor([config.min_signal])
+            max_signal = torch.Tensor([config.max_signal])
             # TODO min_sigma cant be None ?
             self.min_sigma = config.min_sigma
             if weight is None:
-                weight =
-                weight[n_gaussian : 2 * n_gaussian, 1] =
-
+                weight = torch.randn(n_gaussian * 3, n_coeff)
+                weight[n_gaussian : 2 * n_gaussian, 1] = (
+                    torch.log(max_signal - min_signal).float().to(self.device)
+                )
                 weight.requires_grad = True

            self.n_gaussian = weight.shape[0] // 3
            self.n_coeff = weight.shape[1]
            self.weight = weight
-            self.min_signal = torch.Tensor([min_signal])
-            self.max_signal = torch.Tensor([max_signal])
-            self.tol = torch.
+            self.min_signal = torch.Tensor([min_signal]).to(self.device)
+            self.max_signal = torch.Tensor([max_signal]).to(self.device)
+            self.tol = torch.tensor([1e-10]).to(self.device)
+            # TODO refactor to train on CPU!
         else:
             params = np.load(config.path)
-
+            self.mode = "inference"  # TODO better name?

             self.min_signal = torch.Tensor(params["min_signal"])
             self.max_signal = torch.Tensor(params["max_signal"])

-            self.weight = torch.
-                torch.Tensor(params["trained_weight"]), requires_grad=False
-            )
+            self.weight = torch.Tensor(params["trained_weight"])
             self.min_sigma = params["min_sigma"].item()
-            self.n_gaussian = self.weight.shape[0] // 3
+            self.n_gaussian = self.weight.shape[0] // 3  # TODO why // 3 ?
             self.n_coeff = self.weight.shape[1]
-            self.tol = torch.Tensor([1e-10])
-            self.min_signal = torch.Tensor([self.min_signal])
-            self.max_signal = torch.Tensor([self.max_signal])
+            self.tol = torch.Tensor([1e-10])
+            self.min_signal = torch.Tensor([self.min_signal])
+            self.max_signal = torch.Tensor([self.max_signal])

         print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")

-    def make_learnable(self):
-        print(f"[{self.__class__.__name__}] Making noise model learnable")
-        self._learnable = True
-        self.weight.requires_grad = True
-
-    def to_device(self, cuda_tensor):
-        # TODO wtf is this ?
-        # move everything to GPU
-        if self.min_signal.device != cuda_tensor.device:
-            self.max_signal = self.max_signal.cuda()
-            self.min_signal = self.min_signal.cuda()
-            self.tol = self.tol.cuda()
-            # self.weight = self.weight.cuda()
-            if self._learnable:
-                self.weight.requires_grad = True
-
     def polynomialRegressor(self, weightParams, signals):
         """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values.

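The `// 3` that the TODO above asks about follows from how `getGaussianParameters` (further below) consumes the weight matrix: three row groups of K = n_gaussian rows each. A small sketch of that layout, with illustrative sizes:

    import torch

    K, n_coeff = 2, 2  # illustrative values
    weight = torch.randn(K * 3, n_coeff)
    mu_coeffs = weight[:K]           # polynomial coefficients for the means
    var_coeffs = weight[K:2 * K]     # exponentiated, then regressed to variances
    alpha_coeffs = weight[2 * K:]    # regressed and normalized to mixture weights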
@@ -279,6 +328,7 @@ class GaussianMixtureNoiseModel(nn.Module):
         ----------
         weightParams : torch.cuda.FloatTensor
             Corresponds to specific rows of the `self.weight`
+
         signals : torch.cuda.FloatTensor
             Signals

@@ -294,90 +344,91 @@ class GaussianMixtureNoiseModel(nn.Module):
         )
         return value

-    def normalDens(
-
-    ) -> torch.Tensor:
-        """Evaluates the normal probability density at `x` given the mean `m` and
-        standard deviation `std`.
+    def normalDens(self, x, m_=0.0, std_=None):
+        """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`.

         Parameters
         ----------
-        x: torch.
-            Observations
-        m_: torch.
-
-        std_: torch.
-
+        x: torch.cuda.FloatTensor
+            Observations
+        m_: torch.cuda.FloatTensor
+            Mean
+        std_: torch.cuda.FloatTensor
+            Standard-deviation

         Returns
         -------
-        tmp: torch.
+        tmp: torch.cuda.FloatTensor
             Normal probability density of `x` given `m_` and `std_`
+
         """
         tmp = -((x - m_) ** 2)
         tmp = tmp / (2.0 * std_ * std_)
         tmp = torch.exp(tmp)
         tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
+        # print(tmp.min().item(), tmp.mean().item(), tmp.max().item(), tmp.shape)
         return tmp

-    def likelihood(
-
-    ) -> torch.Tensor:
-        """Evaluate the likelihood of observations given the signals and the
-        corresponding gaussian parameters.
+    def likelihood(self, observations, signals):
+        """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.

         Parameters
         ----------
         observations : torch.cuda.FloatTensor
-        Noisy observations
+            Noisy observations
         signals : torch.cuda.FloatTensor
-        Underlying signals
+            Underlying signals

         Returns
         -------
         value :p + self.tol
             Likelihood of observations given the signals and the GMM noise model
+
         """
-        self.
+        if self.mode != "train":
+            signals = signals.cpu()
+            observations = observations.cpu()
+            self.weight = self.weight.to(signals.device)
+            self.min_signal = self.min_signal.to(signals.device)
+            self.max_signal = self.max_signal.to(signals.device)
+            self.tol = self.tol.to(signals.device)
+
         gaussianParameters = self.getGaussianParameters(signals)
         p = 0
         for gaussian in range(self.n_gaussian):
             p += (
                 self.normalDens(
-
-
-
+                    observations,
+                    gaussianParameters[gaussian],
+                    gaussianParameters[self.n_gaussian + gaussian],
                 )
                 * gaussianParameters[2 * self.n_gaussian + gaussian]
             )
         return p + self.tol

-    def getGaussianParameters(self, signals
-        """Returns the noise model for given signals
+    def getGaussianParameters(self, signals):
+        """Returns the noise model for given signals

         Parameters
         ----------
-        signals : torch.
+        signals : torch.cuda.FloatTensor
             Underlying signals

         Returns
         -------
-
-
-            parameters for the `n_gaussian` gaussians in the mixture.
-
+        noiseModel: list of torch.cuda.FloatTensor
+            Contains a list of `mu`, `sigma` and `alpha` for the `signals`
         """
-
+        noiseModel = []
         mu = []
         sigma = []
         alpha = []
         kernels = self.weight.shape[0] // 3
         for num in range(kernels):
-            # For each Gaussian in the mixture, evaluate mean, std and weight
             mu.append(self.polynomialRegressor(self.weight[num, :], signals))
-
+            # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W))
             expval = torch.exp(self.weight[kernels + num, :])
-            #
+            # self.maxval = max(self.maxval, expval.max().item())
             sigmaTemp = self.polynomialRegressor(expval, signals)
             sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
             sigma.append(torch.sqrt(sigmaTemp))
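In equations (notation ours, matching the code: K = n_gaussian, tol = 1e-10), `likelihood` and its helpers compute:

    % Mixture evaluated by likelihood():
    p(x \mid s) = \sum_{k=1}^{K} \alpha_k(s)\, \mathcal{N}\bigl(x;\, \mu_k(s), \sigma_k^2(s)\bigr) + \mathrm{tol}

    % Density computed by normalDens():
    \mathcal{N}(x; \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)

    % Re-centering applied in getGaussianParameters() (the "residual connection" comment):
    \mu_k(s) \leftarrow \mu_k(s) - \sum_{j=1}^{K} \alpha_j(s)\,\mu_j(s) + s

where each μ_k(s), σ_k²(s), α_k(s) is a polynomial in the signal s produced by `polynomialRegressor`, and the α_k are normalized to sum to one.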
@@ -386,7 +437,7 @@ class GaussianMixtureNoiseModel(nn.Module):
                 self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
                 + self.tol
             )
-                alpha.append(expval)
+            alpha.append(expval)

         sum_alpha = 0
         for al in range(kernels):
@@ -403,22 +454,21 @@ class GaussianMixtureNoiseModel(nn.Module):

         # subtracting the alpha weighted average of the means from the means
         # ensures that the GMM has the inclination to have the mean=signals.
-        #
+        # it's like a residual connection. I don't understand why we need to learn the mean?
         for ker in range(kernels):
             mu[ker] = mu[ker] - sum_means + signals

         for i in range(kernels):
-
+            noiseModel.append(mu[i])
         for j in range(kernels):
-
+            noiseModel.append(sigma[j])
         for k in range(kernels):
-
+            noiseModel.append(alpha[k])

-        return
+        return noiseModel

-    # TODO: this is to train the noise model
     def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
-        """Returns the Signal-Observation pixel intensities as a two-column array
+        """Returns the Signal-Observation pixel intensities as a two-column array

         Parameters
         ----------
@@ -433,8 +483,9 @@ class GaussianMixtureNoiseModel(nn.Module):

         Returns
         -------
-
+        noiseModel: list of torch floats
             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
+
         """
         lb = np.percentile(signal, lowerClip)
         ub = np.percentile(signal, upperClip)
@@ -452,13 +503,7 @@ class GaussianMixtureNoiseModel(nn.Module):
         ]
         return fastShuffle(sig_obs_pairs, 2)

-
-    def forward(self, x, y):
-        """Temporary dummy forward method."""
-        return x, y
-
-    # TODO taken from pn2v. Ashesh needs to clarify this
-    def train_noise_model(
+    def fit(
         self,
         signal,
         observation,
@@ -499,9 +544,9 @@ class GaussianMixtureNoiseModel(nn.Module):
         )
         counter = 0
         optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
-
+        loss_arr = []

-
+        for t in range(n_epochs):
             if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
                 counter = 0
                 sig_obs_pairs = fastShuffle(sig_obs_pairs, 1)
@@ -511,31 +556,53 @@ class GaussianMixtureNoiseModel(nn.Module):
             ]
             observations = batch_vectors[:, 1].astype(np.float32)
             signals = batch_vectors[:, 0].astype(np.float32)
-            # TODO do we absolutely need to move to GPU?
             observations = (
-                torch.from_numpy(observations.astype(np.float32))
+                torch.from_numpy(observations.astype(np.float32))
+                .float()
+                .to(self.device)
             )
-            signals = torch.from_numpy(signals).float().
+            signals = torch.from_numpy(signals).float().to(self.device)
+
             p = self.likelihood(observations, signals)
-            loss = torch.mean(-torch.log(p))
-            jointLoss = jointLoss + loss

-
-
+            jointLoss = torch.mean(-torch.log(p))
+            loss_arr.append(jointLoss.item())
+            if self.weight.isnan().any() or self.weight.isinf().any():
+                print(
+                    "NaN or Inf detected in the weights. Aborting training at epoch: ",
+                    t,
+                )
+                break

-            if t %
-
-                min_signal = self.min_signal.cpu().detach().numpy()
-                max_signal = self.max_signal.cpu().detach().numpy()
-                # TODO do we need to save?
-                # np.savez(self.path+name, trained_weight=trained_weight, min_signal = min_signal, max_signal = max_signal, min_sigma = self.min_sigma)
+            if t % 100 == 0:
+                print(t, np.mean(loss_arr))

             optimizer.zero_grad()
             jointLoss.backward()
             optimizer.step()
             counter += 1

+        self.trained_weight = self.weight.cpu().detach().numpy()
+        self.min_signal = self.min_signal.cpu().detach().numpy()
+        self.max_signal = self.max_signal.cpu().detach().numpy()
         print("===================\n")
-
-
-
+
+    def save(self, path: str, name: str):
+        """Save the trained parameters on the noise model.
+
+        Parameters
+        ----------
+        path : str
+            Path to save the trained parameters.
+        name : str
+            File name to save the trained parameters.
+        """
+        os.makedirs(path, exist_ok=True)
+        np.savez(
+            os.path.join(path, name),
+            trained_weight=self.trained_weight,
+            min_signal=self.min_signal,
+            max_signal=self.max_signal,
+            min_sigma=self.min_sigma,
+        )
+        print("The trained parameters (" + name + ") is saved at location: " + path)