careamics 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Potentially problematic release.

@@ -3,6 +3,7 @@ from __future__ import annotations
  import os
  from typing import TYPE_CHECKING, Optional

+ from numpy.typing import NDArray
  import numpy as np
  import torch
  import torch.nn as nn
@@ -13,63 +14,59 @@ if TYPE_CHECKING:
  # TODO this module shouldn't be in lvae folder


- def create_histogram(bins, min_val, max_val, observation, signal):
+ def create_histogram(
+     bins: int, min_val: float, max_val: float, observation: NDArray, signal: NDArray
+ ) -> NDArray:
      """
      Creates a 2D histogram from 'observation' and 'signal'.

      Parameters
      ----------
-     bins: int
-         The number of bins in x and y. The total number of 2D bins is 'bins'**2.
-     min_val: float
-         the lower bound of the lowest bin in x and y.
-     max_val: float
-         the highest bound of the highest bin in x and y.
-     observation: numpy array
-         A 3D numpy array that is interpretted as a stack of 2D images.
-         The number of images has to be divisible by the number of images in 'signal'.
-         It is assumed that n subsequent images in observation belong to one image image in 'signal'.
-     signal: numpy array
-         A 3D numpy array that is interpretted as a stack of 2D images.
+     bins : int
+         Number of bins in x and y.
+     min_val : float
+         Lower bound of the lowest bin in x and y.
+     max_val : float
+         Upper bound of the highest bin in x and y.
+     observation : np.ndarray
+         3D numpy array (stack of 2D images).
+         Observation.shape[0] must be divisible by signal.shape[0].
+         Assumes that n subsequent images in observation belong to one image in 'signal'.
+     signal : np.ndarray
+         3D numpy array (stack of 2D images).

      Returns
      -------
-     histogram: numpy array
+     histogram : np.ndarray
          A 3D array:
-         'histogram[0,...]' holds the normalized 2D counts.
-         Each row sums to 1, describing p(x_i|s_i).
-         'histogram[1,...]' holds the lower boundaries of each bin in y.
-         'histogram[2,...]' holds the upper boundaries of each bin in y.
-         The values for x can be obtained by transposing 'histogram[1,...]' and 'histogram[2,...]'.
+         - histogram[0]: Normalized 2D counts.
+         - histogram[1]: Lower boundaries of bins along y.
+         - histogram[2]: Upper boundaries of bins along y.
+         The values for x can be obtained by transposing 'histogram[1]' and 'histogram[2]'.
      """
-     # TODO refactor this function
-     img_factor = int(observation.shape[0] / signal.shape[0])
      histogram = np.zeros((3, bins, bins))
-     ra = [min_val, max_val]
-
-     for i in range(observation.shape[0]):
-         observation_ = observation[i].copy().ravel()
-
-         signal_ = (signal[i // img_factor].copy()).ravel()
-         a = np.histogram2d(signal_, observation_, bins=bins, range=[ra, ra])
-         histogram[0] = histogram[0] + a[0] + 1e-30  # This is for numerical stability
-
-     for i in range(bins):
-         if (
-             np.sum(histogram[0, i, :]) > 1e-20
-         ):  # We exclude empty rows from normalization
-             histogram[0, i, :] /= np.sum(
-                 histogram[0, i, :]
-             )  # we normalize each non-empty row
-
-     for i in range(bins):
-         histogram[1, :, i] = a[1][
-             :-1
-         ]  # The lower boundaries of each bin in y are stored in dimension 1
-         histogram[2, :, i] = a[1][
-             1:
-         ]  # The upper boundaries of each bin in y are stored in dimension 2
-     # The accordent numbers for x are just transopsed.
+
+     value_range = [min_val, max_val]
+
+     # Compute mapping factor between observation and signal samples
+     obs_to_signal_shape_factor = int(observation.shape[0] / signal.shape[0])
+
+     # Flatten arrays and align signal values
+     signal_indices = np.arange(observation.shape[0]) // obs_to_signal_shape_factor
+     signal_values = signal[signal_indices].ravel()
+     observation_values = observation.ravel()
+
+     count_histogram, signal_edges, _ = np.histogram2d(
+         signal_values, observation_values, bins=bins, range=[value_range, value_range]
+     )
+
+     # Normalize rows to obtain probabilities
+     row_sums = count_histogram.sum(axis=1, keepdims=True)
+     count_histogram /= np.clip(row_sums, a_min=1e-20, a_max=None)
+
+     histogram[0] = count_histogram
+     histogram[1] = signal_edges[:-1][..., np.newaxis]
+     histogram[2] = signal_edges[1:][..., np.newaxis]

      return histogram

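For orientation, a minimal usage sketch of the rewritten function; the shapes and values below are illustrative, not taken from the package:

    import numpy as np

    # hypothetical data: 8 noisy frames per clean frame
    signal = np.random.rand(2, 64, 64)
    observation = np.random.rand(16, 64, 64)

    hist = create_histogram(
        bins=128, min_val=0.0, max_val=1.0, observation=observation, signal=signal
    )
    assert hist.shape == (3, 128, 128)  # counts, lower bin edges, upper bin edges
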
@@ -111,8 +108,11 @@ def noise_model_factory(
          # TODO train a new model. Config should always be provided?
          if nm.model_type == "GaussianMixtureNoiseModel":
              # TODO one model for each channel all make this choise inside the model?
-             trained_nm = train_gm_noise_model(nm)
-             noise_models.append(trained_nm)
+             # trained_nm = train_gm_noise_model(nm)
+             # noise_models.append(trained_nm)
+             raise NotImplementedError(
+                 "GaussianMixtureNoiseModel model training is not implemented."
+             )
          else:
              raise NotImplementedError(
                  f"Model {nm.model_type} is not implemented"
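With training removed from the factory path, a GaussianMixtureNoiseModel can only enter the pipeline pretrained. A hedged sketch of loading one from disk; the file name is hypothetical, and the config fields follow the diff's own usage of config.path:

    config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        path="gmm_noise_model.npz",  # hypothetical file written by GaussianMixtureNoiseModel.save()
    )
    noise_model = GaussianMixtureNoiseModel(config)
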
@@ -163,6 +163,8 @@ class MultiChannelNoiseModel(nn.Module):
              List of noise models, one for each output channel.
          """
          super().__init__()
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
          for i, nmodel in enumerate(nmodels):  # TODO refactor this !!!
              if nmodel is not None:
                  self.add_module(
@@ -176,6 +178,13 @@ class MultiChannelNoiseModel(nn.Module):

          print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")

+     def to_device(self, device: torch.device):
+         self.device = device
+         self.to(device)
+         for ch_idx in range(self._nm_cnt):
+             nmodel = getattr(self, f"nmodel_{ch_idx}")
+             nmodel.to_device(device)
+
      def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
          """Compute the likelihood of observations given signals for each channel.

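A quick usage sketch for the new device plumbing; the multi_nm instance name is hypothetical:

    # move the container module and every per-channel noise model in one call
    if torch.cuda.is_available():
        multi_nm.to_device(torch.device("cuda"))
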
@@ -212,28 +221,6 @@ class MultiChannelNoiseModel(nn.Module):
          return torch.cat(ll_list, dim=1)


- # TODO: is this needed?
- def fastShuffle(series, num):
-     """_summary_.
-
-     Parameters
-     ----------
-     series : _type_
-         _description_
-     num : _type_
-         _description_
-
-     Returns
-     -------
-     _type_
-         _description_
-     """
-     length = series.shape[0]
-     for _ in range(num):
-         series = series[np.random.permutation(length), :]
-     return series
-
-
  class GaussianMixtureNoiseModel(nn.Module):
      """Define a noise model parameterized as a mixture of gaussians.

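The module-level fastShuffle deleted here is not gone: it resurfaces further down in this diff as the static method GaussianMixtureNoiseModel._fast_shuffle, operating on tensors via torch.randperm instead of np.random.permutation. A minimal equivalence sketch, with illustrative values:

    import torch

    series = torch.arange(12.0).reshape(6, 2)              # two-column tensor, like sig_obs_pairs
    shuffled = series[torch.randperm(series.shape[0]), :]  # one shuffle pass
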
@@ -276,166 +263,176 @@ class GaussianMixtureNoiseModel(nn.Module):
      """

      # TODO training a NM relies on getting a clean data(N2V e.g,)
-     def __init__(self, config: GaussianMixtureNMConfig):
+     def __init__(self, config: GaussianMixtureNMConfig) -> None:
          super().__init__()
+         self.device = torch.device("cpu")

-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         if config.path is None:
-             self.mode = "train"
-             # TODO this is (probably) to train a nm. We leave it for later refactoring
-             weight = config.weight
-             n_gaussian = config.n_gaussian
-             n_coeff = config.n_coeff
-             min_signal = torch.Tensor([config.min_signal])
-             max_signal = torch.Tensor([config.max_signal])
-             # TODO min_sigma cant be None ?
-             self.min_sigma = config.min_sigma
-             if weight is None:
-                 weight = torch.randn(n_gaussian * 3, n_coeff)
-                 weight[n_gaussian : 2 * n_gaussian, 1] = (
-                     torch.log(max_signal - min_signal).float().to(self.device)
-                 )
-                 weight.requires_grad = True
-
-             self.n_gaussian = weight.shape[0] // 3
-             self.n_coeff = weight.shape[1]
-             self.weight = weight
-             self.min_signal = torch.Tensor([min_signal]).to(self.device)
-             self.max_signal = torch.Tensor([max_signal]).to(self.device)
-             self.tol = torch.tensor([1e-10]).to(self.device)
-             # TODO refactor to train on CPU!
-         else:
+         if config.path is not None:
              params = np.load(config.path)
-             self.mode = "inference"  # TODO better name?
+         else:
+             params = config.model_dump(exclude_none=True)
+
+         min_sigma = torch.tensor(params["min_sigma"])
+         min_signal = torch.tensor(params["min_signal"])
+         max_signal = torch.tensor(params["max_signal"])
+         self.register_buffer("min_signal", min_signal)
+         self.register_buffer("max_signal", max_signal)
+         self.register_buffer("min_sigma", min_sigma)
+         self.register_buffer("tolerance", torch.tensor([1e-10]))
+
+         if "trained_weight" in params:
+             weight = torch.tensor(params["trained_weight"])
+         elif "weight" in params and params["weight"] is not None:
+             weight = torch.tensor(params["weight"])
+         else:
+             weight = self._initialize_weights(
+                 params["n_gaussian"], params["n_coeff"], max_signal, min_signal
+             )

-             self.min_signal = torch.Tensor(params["min_signal"])
-             self.max_signal = torch.Tensor(params["max_signal"])
+         self.n_gaussian = weight.shape[0] // 3
+         self.n_coeff = weight.shape[1]

-             self.weight = torch.Tensor(params["trained_weight"])
-             self.min_sigma = params["min_sigma"].item()
-             self.n_gaussian = self.weight.shape[0] // 3  # TODO why // 3 ?
-             self.n_coeff = self.weight.shape[1]
-             self.tol = torch.Tensor([1e-10])
-             self.min_signal = torch.Tensor([self.min_signal])
-             self.max_signal = torch.Tensor([self.max_signal])
+         self.register_parameter("weight", nn.Parameter(weight))
+         self._set_model_mode(mode="prediction")

          print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")

-     def polynomialRegressor(self, weightParams, signals):
-         """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values.
+     def _initialize_weights(
+         self,
+         n_gaussian: int,
+         n_coeff: int,
+         max_signal: torch.Tensor,
+         min_signal: torch.Tensor,
+     ) -> torch.Tensor:
+         """Create random weight initialization."""
+         weight = torch.randn(n_gaussian * 3, n_coeff)
+         weight[n_gaussian : 2 * n_gaussian, 1] = torch.log(
+             max_signal - min_signal
+         ).float()
+         return weight
+
+     def to_device(self, device: torch.device):
+         self.device = device
+         self.to(device)
+
+     def _set_model_mode(self, mode: str) -> None:
+         """Move parameters to the device and set weights' requires_grad depending on the mode"""
+         if mode == "train":
+             self.weight.requires_grad = True
+         else:
+             self.weight.requires_grad = False
+
+     def polynomial_regressor(
+         self, weight_params: torch.Tensor, signals: torch.Tensor
+     ) -> torch.Tensor:
+         """Combines `weight_params` and signal `signals` to regress for the gaussian parameter values.

          Parameters
          ----------
-         weightParams : torch.cuda.FloatTensor
+         weight_params : Tensor
              Corresponds to specific rows of the `self.weight`

-         signals : torch.cuda.FloatTensor
+         signals : Tensor
              Signals

          Returns
          -------
-         value : torch.cuda.FloatTensor
+         value : Tensor
              Corresponds to either of mean, standard deviation or weight, evaluated at `signals`
          """
-         value = 0
-         for i in range(weightParams.shape[0]):
-             value += weightParams[i] * (
+         value = torch.zeros_like(signals)
+         for i in range(weight_params.shape[0]):
+             value += weight_params[i] * (
                  ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
              )
          return value

-     def normalDens(self, x, m_=0.0, std_=None):
-         """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`.
+     def normal_density(
+         self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Evaluates the normal probability density at `x` given the mean `mean` and standard deviation `std`.

          Parameters
          ----------
-         x: torch.cuda.FloatTensor
+         x: Tensor
              Observations
-         m_: torch.cuda.FloatTensor
+         mean: Tensor
              Mean
-         std_: torch.cuda.FloatTensor
+         std: Tensor
              Standard-deviation

          Returns
          -------
-         tmp: torch.cuda.FloatTensor
-             Normal probability density of `x` given `m_` and `std_`
-
+         tmp: Tensor
+             Normal probability density of `x` given `mean` and `std`
          """
-         tmp = -((x - m_) ** 2)
-         tmp = tmp / (2.0 * std_ * std_)
+         tmp = -((x - mean) ** 2)
+         tmp = tmp / (2.0 * std * std)
          tmp = torch.exp(tmp)
-         tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
-         # print(tmp.min().item(), tmp.mean().item(), tmp.max().item(), tmp.shape)
+         tmp = tmp / torch.sqrt((2.0 * np.pi) * std * std)
          return tmp

-     def likelihood(self, observations, signals):
-         """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.
+     def likelihood(
+         self, observations: torch.Tensor, signals: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.

          Parameters
          ----------
-         observations : torch.cuda.FloatTensor
+         observations : Tensor
              Noisy observations
-         signals : torch.cuda.FloatTensor
+         signals : Tensor
              Underlying signals

          Returns
          -------
-         value :p + self.tol
+         value: torch.Tensor:
              Likelihood of observations given the signals and the GMM noise model
-
          """
-         if self.mode != "train":
-             signals = signals.cpu()
-             observations = observations.cpu()
-             self.weight = self.weight.to(signals.device)
-             self.min_signal = self.min_signal.to(signals.device)
-             self.max_signal = self.max_signal.to(signals.device)
-             self.tol = self.tol.to(signals.device)
-
-         gaussianParameters = self.getGaussianParameters(signals)
+         gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
          p = 0
          for gaussian in range(self.n_gaussian):
              p += (
-                 self.normalDens(
+                 self.normal_density(
                      observations,
-                     gaussianParameters[gaussian],
-                     gaussianParameters[self.n_gaussian + gaussian],
+                     gaussian_parameters[gaussian],
+                     gaussian_parameters[self.n_gaussian + gaussian],
                  )
-                 * gaussianParameters[2 * self.n_gaussian + gaussian]
+                 * gaussian_parameters[2 * self.n_gaussian + gaussian]
              )
-         return p + self.tol
+         return p + self.tolerance

-     def getGaussianParameters(self, signals):
-         """Returns the noise model for given signals
+     def get_gaussian_parameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
+         """
+         Returns the noise model for given signals

          Parameters
          ----------
-         signals : torch.cuda.FloatTensor
+         signals : Tensor
              Underlying signals

          Returns
          -------
-         noiseModel: list of torch.cuda.FloatTensor
+         noise_model: list of Tensor
              Contains a list of `mu`, `sigma` and `alpha` for the `signals`
          """
-         noiseModel = []
+         noise_model = []
          mu = []
          sigma = []
          alpha = []
          kernels = self.weight.shape[0] // 3
          for num in range(kernels):
-             mu.append(self.polynomialRegressor(self.weight[num, :], signals))
-             # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W))
+             mu.append(self.polynomial_regressor(self.weight[num, :], signals))
              expval = torch.exp(self.weight[kernels + num, :])
-             # self.maxval = max(self.maxval, expval.max().item())
-             sigmaTemp = self.polynomialRegressor(expval, signals)
-             sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
-             sigma.append(torch.sqrt(sigmaTemp))
+             sigma_temp = self.polynomial_regressor(expval, signals)
+             sigma_temp = torch.clamp(sigma_temp, min=self.min_sigma)
+             sigma.append(torch.sqrt(sigma_temp))

              expval = torch.exp(
-                 self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
-                 + self.tol
+                 self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
+                 + self.tolerance
              )
              alpha.append(expval)

@@ -459,15 +456,30 @@ class GaussianMixtureNoiseModel(nn.Module):
              mu[ker] = mu[ker] - sum_means + signals

          for i in range(kernels):
-             noiseModel.append(mu[i])
+             noise_model.append(mu[i])
          for j in range(kernels):
-             noiseModel.append(sigma[j])
+             noise_model.append(sigma[j])
          for k in range(kernels):
-             noiseModel.append(alpha[k])
+             noise_model.append(alpha[k])
+
+         return noise_model

-         return noiseModel
+     @staticmethod
+     def _fast_shuffle(series: torch.Tensor, num: int) -> torch.Tensor:
+         """Shuffle the inputs randomly num times"""
+         length = series.shape[0]
+         for _ in range(num):
+             idx = torch.randperm(length)
+             series = series[idx, :]
+         return series

-     def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
+     def get_signal_observation_pairs(
+         self,
+         signal: NDArray,
+         observation: NDArray,
+         lower_clip: float,
+         upper_clip: float,
+     ) -> torch.Tensor:
          """Returns the Signal-Observation pixel intensities as a two-column array

          Parameters
@@ -476,19 +488,18 @@ class GaussianMixtureNoiseModel(nn.Module):
              Clean Signal Data
          observation: numpy array
              Noisy observation Data
-         lowerClip: float
+         lower_clip: float
              Lower percentile bound for clipping.
-         upperClip: float
+         upper_clip: float
              Upper percentile bound for clipping.

          Returns
          -------
-         noiseModel: list of torch floats
+         noise_model: list of torch floats
              Contains a list of `mu`, `sigma` and `alpha` for the `signals`
-
          """
-         lb = np.percentile(signal, lowerClip)
-         ub = np.percentile(signal, upperClip)
+         lb = np.percentile(signal, lower_clip)
+         ub = np.percentile(signal, upper_clip)
          stepsize = observation[0].size
          n_observations = observation.shape[0]
          n_signals = signal.shape[0]
@@ -501,19 +512,20 @@ class GaussianMixtureNoiseModel(nn.Module):
          sig_obs_pairs = sig_obs_pairs[
              (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
          ]
-         return fastShuffle(sig_obs_pairs, 2)
+         sig_obs_pairs = sig_obs_pairs.astype(np.float32)
+         sig_obs_pairs = torch.from_numpy(sig_obs_pairs)
+         return self._fast_shuffle(sig_obs_pairs, 2)

      def fit(
          self,
-         signal,
-         observation,
-         learning_rate=1e-1,
-         batchSize=250000,
-         n_epochs=2000,
-         name="GMMNoiseModel.npz",
-         lowerClip=0,
-         upperClip=100,
-     ):
+         signal: NDArray,
+         observation: NDArray,
+         learning_rate: float = 1e-1,
+         batch_size: int = 250000,
+         n_epochs: int = 2000,
+         lower_clip: float = 0.0,
+         upper_clip: float = 100.0,
+     ) -> list[float]:
          """Training to learn the noise model from signal - observation pairs.

          Parameters
@@ -524,49 +536,42 @@ class GaussianMixtureNoiseModel(nn.Module):
              Noisy Observation Data
          learning_rate: float
              Learning rate. Default = 1e-1.
-         batchSize: int
+         batch_size: int
              Nini-batch size. Default = 250000.
          n_epochs: int
              Number of epochs. Default = 2000.
-         name: string
-
-             Model name. Default is `GMMNoiseModel`. This model after being trained is saved at the location `path`.
-
-         lowerClip : int
+         lower_clip : int
              Lower percentile for clipping. Default is 0.
-         upperClip : int
+         upper_clip : int
              Upper percentile for clipping. Default is 100.
-
-
          """
-         sig_obs_pairs = self.getSignalObservationPairs(
-             signal, observation, lowerClip, upperClip
-         )
-         counter = 0
+         self._set_model_mode(mode="train")
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.to_device(device)
          optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
-         loss_arr = []

+         sig_obs_pairs = self.get_signal_observation_pairs(
+             signal, observation, lower_clip, upper_clip
+         )
+
+         train_losses = []
+         counter = 0
          for t in range(n_epochs):
-             if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
+             if (counter + 1) * batch_size >= sig_obs_pairs.shape[0]:
                  counter = 0
-                 sig_obs_pairs = fastShuffle(sig_obs_pairs, 1)
+                 sig_obs_pairs = self._fast_shuffle(sig_obs_pairs, 1)

              batch_vectors = sig_obs_pairs[
-                 counter * batchSize : (counter + 1) * batchSize, :
+                 counter * batch_size : (counter + 1) * batch_size, :
              ]
-             observations = batch_vectors[:, 1].astype(np.float32)
-             signals = batch_vectors[:, 0].astype(np.float32)
-             observations = (
-                 torch.from_numpy(observations.astype(np.float32))
-                 .float()
-                 .to(self.device)
-             )
-             signals = torch.from_numpy(signals).float().to(self.device)
+             observations = batch_vectors[:, 1].to(self.device)
+             signals = batch_vectors[:, 0].to(self.device)

              p = self.likelihood(observations, signals)

-             jointLoss = torch.mean(-torch.log(p))
-             loss_arr.append(jointLoss.item())
+             joint_loss = torch.mean(-torch.log(p))
+             train_losses.append(joint_loss.item())
+
              if self.weight.isnan().any() or self.weight.isinf().any():
                  print(
                      "NaN or Inf detected in the weights. Aborting training at epoch: ",
@@ -575,19 +580,77 @@ class GaussianMixtureNoiseModel(nn.Module):
                  break

              if t % 100 == 0:
-                 print(t, np.mean(loss_arr))
+                 last_losses = train_losses[-100:]
+                 print(t, np.mean(last_losses))

              optimizer.zero_grad()
-             jointLoss.backward()
+             joint_loss.backward()
              optimizer.step()
              counter += 1

-         self.trained_weight = self.weight.cpu().detach().numpy()
-         self.min_signal = self.min_signal.cpu().detach().numpy()
-         self.max_signal = self.max_signal.cpu().detach().numpy()
+         self._set_model_mode(mode="prediction")
+         self.to_device(torch.device("cpu"))
          print("===================\n")
+         return train_losses
+
+     def sample_observation_from_signal(self, signal: NDArray) -> NDArray:
+         """
+         Sample an instance of observation based on an input signal using a
+         learned Gaussian Mixture Model. For each pixel in the input signal,
+         samples a corresponding noisy pixel.
+
+         Parameters
+         ----------
+         signal: numpy array
+             Clean 2D signal data.
+
+         Returns
+         -------
+         observation: numpy array
+             An instance of noisy observation data based on the input signal.
+         """
+         assert len(signal.shape) == 2, "Only 2D inputs are supported."
+
+         signal_tensor = torch.from_numpy(signal).to(torch.float32)
+         height, width = signal_tensor.shape
+
+         with torch.no_grad():
+             # Get gaussian parameters for each pixel
+             gaussian_params = self.get_gaussian_parameters(signal_tensor)
+             means = np.array(gaussian_params[: self.n_gaussian])
+             stds = np.array(gaussian_params[self.n_gaussian : self.n_gaussian * 2])
+             alphas = np.array(gaussian_params[self.n_gaussian * 2 :])
+
+             if self.n_gaussian == 1:
+                 # Single gaussian case
+                 observation = np.random.normal(
+                     loc=means[0], scale=stds[0], size=(height, width)
+                 )
+             else:
+                 # Multiple gaussians: sample component for each pixel
+                 uniform = np.random.rand(1, height, width)
+                 # Compute cumulative probabilities for component selection
+                 cumulative_alphas = np.cumsum(
+                     alphas, axis=0
+                 )  # Shape: (n_gaussian, height, width)
+                 selected_component = np.argmax(
+                     uniform < cumulative_alphas, axis=0, keepdims=True
+                 )
+
+                 # For every pixel, choose the corresponding gaussian
+                 # and get the learned mu and sigma
+                 selected_mus = np.take_along_axis(means, selected_component, axis=0)
+                 selected_stds = np.take_along_axis(stds, selected_component, axis=0)
+                 selected_mus = selected_mus.squeeze(0)
+                 selected_stds = selected_stds.squeeze(0)
+
+                 # Sample from the normal distribution with learned mu and sigma
+                 observation = np.random.normal(
+                     selected_mus, selected_stds, size=(height, width)
+                 )
+         return observation

-     def save(self, path: str, name: str):
+     def save(self, path: str, name: str) -> None:
          """Save the trained parameters on the noise model.

          Parameters
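Putting the refactored training and sampling API together, a hedged end-to-end sketch; the data, shapes, and the config variable below are illustrative:

    import numpy as np

    signal = np.random.rand(4, 64, 64) * 100            # hypothetical clean frames
    observation = signal + np.random.randn(4, 64, 64)   # hypothetical noisy counterparts

    nm = GaussianMixtureNoiseModel(config)              # config: a GaussianMixtureNMConfig
    losses = nm.fit(signal, observation, n_epochs=200)  # fit() now returns the loss history
    noisy_sample = nm.sample_observation_from_signal(signal[0])  # 2D inputs only
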
@@ -600,9 +663,9 @@ class GaussianMixtureNoiseModel(nn.Module):
          os.makedirs(path, exist_ok=True)
          np.savez(
              os.path.join(path, name),
-             trained_weight=self.trained_weight,
-             min_signal=self.min_signal,
-             max_signal=self.max_signal,
+             trained_weight=self.weight.numpy(),
+             min_signal=self.min_signal.numpy(),
+             max_signal=self.max_signal.numpy(),
              min_sigma=self.min_sigma,
          )
          print("The trained parameters (" + name + ") is saved at location: " + path)
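Since save() writes trained_weight, min_signal, max_signal, and min_sigma into an .npz file, that file can seed a fresh model through config.path, closing the loop with the __init__ refactor above. A hedged round-trip sketch; the paths and config construction are illustrative:

    nm.save("noise_models/", "gmm_noise_model.npz")
    reloaded = GaussianMixtureNoiseModel(
        GaussianMixtureNMConfig(
            model_type="GaussianMixtureNoiseModel",
            path="noise_models/gmm_noise_model.npz",
        )
    )
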