careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ import os
  from typing import TYPE_CHECKING, Optional

  import numpy as np
@@ -12,6 +13,67 @@ if TYPE_CHECKING:
  # TODO this module shouldn't be in lvae folder


+ def create_histogram(bins, min_val, max_val, observation, signal):
+     """
+     Creates a 2D histogram from 'observation' and 'signal'.
+
+     Parameters
+     ----------
+     bins: int
+         The number of bins in x and y. The total number of 2D bins is 'bins'**2.
+     min_val: float
+         The lower bound of the lowest bin in x and y.
+     max_val: float
+         The upper bound of the highest bin in x and y.
+     observation: numpy array
+         A 3D numpy array that is interpreted as a stack of 2D images.
+         The number of images has to be divisible by the number of images in 'signal'.
+         It is assumed that n subsequent images in 'observation' belong to one image in 'signal'.
+     signal: numpy array
+         A 3D numpy array that is interpreted as a stack of 2D images.
+
+     Returns
+     -------
+     histogram: numpy array
+         A 3D array:
+         'histogram[0,...]' holds the normalized 2D counts.
+         Each row sums to 1, describing p(x_i|s_i).
+         'histogram[1,...]' holds the lower boundaries of each bin in y.
+         'histogram[2,...]' holds the upper boundaries of each bin in y.
+         The values for x can be obtained by transposing 'histogram[1,...]' and 'histogram[2,...]'.
+     """
+     # TODO refactor this function
+     img_factor = int(observation.shape[0] / signal.shape[0])
+     histogram = np.zeros((3, bins, bins))
+     ra = [min_val, max_val]
+
+     for i in range(observation.shape[0]):
+         observation_ = observation[i].copy().ravel()
+
+         signal_ = (signal[i // img_factor].copy()).ravel()
+         a = np.histogram2d(signal_, observation_, bins=bins, range=[ra, ra])
+         histogram[0] = histogram[0] + a[0] + 1e-30  # This is for numerical stability
+
+     for i in range(bins):
+         if (
+             np.sum(histogram[0, i, :]) > 1e-20
+         ):  # We exclude empty rows from normalization
+             histogram[0, i, :] /= np.sum(
+                 histogram[0, i, :]
+             )  # we normalize each non-empty row
+
+     for i in range(bins):
+         histogram[1, :, i] = a[1][
+             :-1
+         ]  # The lower boundaries of each bin in y are stored in dimension 1
+         histogram[2, :, i] = a[1][
+             1:
+         ]  # The upper boundaries of each bin in y are stored in dimension 2
+         # The corresponding numbers for x are just transposed.
+
+     return histogram
+
+
  def noise_model_factory(
      model_config: Optional[MultiChannelNMConfig],
  ) -> Optional[MultiChannelNoiseModel]:
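
For orientation, a minimal usage sketch of the new create_histogram helper. The synthetic arrays and bin count are illustrative, not from the package; only the signature and return layout come from the diff above.

    import numpy as np
    from careamics.models.lvae.noise_models import create_histogram

    signal = 100.0 * np.random.rand(10, 64, 64)     # 10 "clean" 2D images
    observation = np.repeat(signal, 2, axis=0)      # 2 noisy frames per clean image
    observation = observation + np.random.randn(*observation.shape)

    hist = create_histogram(
        bins=128,
        min_val=float(observation.min()),
        max_val=float(observation.max()),
        observation=observation,
        signal=signal,
    )
    # hist[0]: row-normalized counts p(x_i | s_i);
    # hist[1] / hist[2]: lower / upper bin edges in y.
    print(hist.shape)  # (3, 128, 128)
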
@@ -36,23 +98,24 @@ def noise_model_factory(
      """
      if model_config:
          noise_models = []
-         for nm_config in model_config.noise_models:
-             if nm_config.path:
-                 if nm_config.model_type == "GaussianMixtureNoiseModel":
-                     noise_models.append(GaussianMixtureNoiseModel(nm_config))
+         for nm in model_config.noise_models:
+             if nm.path:
+                 if nm.model_type == "GaussianMixtureNoiseModel":
+                     noise_models.append(GaussianMixtureNoiseModel(nm))
                  else:
                      raise NotImplementedError(
-                         f"Model {nm_config.model_type} is not implemented"
+                         f"Model {nm.model_type} is not implemented"
                      )

              else:  # TODO this means signal/obs are provided. Controlled in pydantic model
                  # TODO train a new model. Config should always be provided?
-                 if nm_config.model_type == "GaussianMixtureNoiseModel":
-                     trained_nm = train_gm_noise_model(nm_config)
+                 if nm.model_type == "GaussianMixtureNoiseModel":
+                     # TODO one model for each channel, or make this choice inside the model?
+                     trained_nm = train_gm_noise_model(nm)
                      noise_models.append(trained_nm)
                  else:
                      raise NotImplementedError(
-                         f"Model {nm_config.model_type} is not implemented"
+                         f"Model {nm.model_type} is not implemented"
                      )
          return MultiChannelNoiseModel(noise_models)
      return None
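
For context, a hedged sketch of how the renamed loop is driven: a MultiChannelNMConfig wrapping one GaussianMixtureNMConfig per channel. The class names come from the type hints in this diff and careamics/config/nm_model.py in the file list; the import path and the .npz path are assumptions.

    from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig

    gmm_cfg = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        path="noise_model_ch0.npz",  # assumed pre-trained file; path set -> load, not train
    )
    noise_model = noise_model_factory(MultiChannelNMConfig(noise_models=[gmm_cfg]))
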
@@ -60,6 +123,8 @@ def noise_model_factory(

  def train_gm_noise_model(
      model_config: GaussianMixtureNMConfig,
+     signal: np.ndarray,
+     observation: np.ndarray,
  ) -> GaussianMixtureNoiseModel:
      """Train a Gaussian mixture noise model.

@@ -76,7 +141,7 @@ def train_gm_noise_model(
      # TODO any training params ? Different channels ?
      noise_model = GaussianMixtureNoiseModel(model_config)
      # TODO revisit config unpacking
-     noise_model.train_noise_model(model_config.signal, model_config.observation)
+     noise_model.fit(signal, observation)
      return noise_model


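The calibration data now travels as explicit arrays instead of riding on the config. A sketch of the new calling convention, assuming the config's min_signal/max_signal fields are set and the remaining fields keep their defaults; the file names are placeholders.

    import numpy as np
    from careamics.config.nm_model import GaussianMixtureNMConfig

    signal = np.load("clean_estimate.npy")           # hypothetical inputs
    observation = np.load("noisy_acquisitions.npy")

    cfg = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        min_signal=float(signal.min()),
        max_signal=float(signal.max()),
    )
    trained_nm = train_gm_noise_model(cfg, signal=signal, observation=observation)
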
@@ -98,7 +163,7 @@ class MultiChannelNoiseModel(nn.Module):
          List of noise models, one for each output channel.
      """
      super().__init__()
-     for i, nmodel in enumerate(nmodels):
+     for i, nmodel in enumerate(nmodels):  # TODO refactor this !!!
          if nmodel is not None:
              self.add_module(
                  f"nmodel_{i}", nmodel
@@ -213,65 +278,49 @@ class GaussianMixtureNoiseModel(nn.Module):
      # TODO training a NM relies on getting clean data (N2V e.g.)
      def __init__(self, config: GaussianMixtureNMConfig):
          super().__init__()
-         self._learnable = False

+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          if config.path is None:
+             self.mode = "train"
              # TODO this is (probably) to train a nm. We leave it for later refactoring
              weight = config.weight
              n_gaussian = config.n_gaussian
              n_coeff = config.n_coeff
-             min_signal = config.min_signal
-             max_signal = config.max_signal
-             # self.device = kwargs.get('device')
+             min_signal = torch.Tensor([config.min_signal])
+             max_signal = torch.Tensor([config.max_signal])
              # TODO min_sigma can't be None ?
              self.min_sigma = config.min_sigma
              if weight is None:
-                 weight = np.random.randn(n_gaussian * 3, n_coeff)
-                 weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
-                 weight = torch.from_numpy(weight.astype(np.float32)).float().cuda()
+                 weight = torch.randn(n_gaussian * 3, n_coeff)
+                 weight[n_gaussian : 2 * n_gaussian, 1] = (
+                     torch.log(max_signal - min_signal).float().to(self.device)
+                 )
                  weight.requires_grad = True

              self.n_gaussian = weight.shape[0] // 3
              self.n_coeff = weight.shape[1]
              self.weight = weight
-             self.min_signal = torch.Tensor([min_signal])
-             self.max_signal = torch.Tensor([max_signal])
-             self.tol = torch.Tensor([1e-10])
+             self.min_signal = torch.Tensor([min_signal]).to(self.device)
+             self.max_signal = torch.Tensor([max_signal]).to(self.device)
+             self.tol = torch.tensor([1e-10]).to(self.device)
+             # TODO refactor to train on CPU!
          else:
              params = np.load(config.path)
-             # self.device = kwargs.get('device')
+             self.mode = "inference"  # TODO better name?

              self.min_signal = torch.Tensor(params["min_signal"])
              self.max_signal = torch.Tensor(params["max_signal"])

-             self.weight = torch.nn.Parameter(
-                 torch.Tensor(params["trained_weight"]), requires_grad=False
-             )
+             self.weight = torch.Tensor(params["trained_weight"])
              self.min_sigma = params["min_sigma"].item()
-             self.n_gaussian = self.weight.shape[0] // 3
+             self.n_gaussian = self.weight.shape[0] // 3  # TODO why // 3 ?
              self.n_coeff = self.weight.shape[1]
-             self.tol = torch.Tensor([1e-10])  # .to(self.device)
-             self.min_signal = torch.Tensor([self.min_signal])  # .to(self.device)
-             self.max_signal = torch.Tensor([self.max_signal])  # .to(self.device)
+             self.tol = torch.Tensor([1e-10])
+             self.min_signal = torch.Tensor([self.min_signal])
+             self.max_signal = torch.Tensor([self.max_signal])

          print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")

-     def make_learnable(self):
-         print(f"[{self.__class__.__name__}] Making noise model learnable")
-         self._learnable = True
-         self.weight.requires_grad = True
-
-     def to_device(self, cuda_tensor):
-         # TODO wtf is this ?
-         # move everything to GPU
-         if self.min_signal.device != cuda_tensor.device:
-             self.max_signal = self.max_signal.cuda()
-             self.min_signal = self.min_signal.cuda()
-             self.tol = self.tol.cuda()
-             # self.weight = self.weight.cuda()
-             if self._learnable:
-                 self.weight.requires_grad = True
-
      def polynomialRegressor(self, weightParams, signals):
          """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values.

@@ -279,6 +328,7 @@ class GaussianMixtureNoiseModel(nn.Module):
          ----------
          weightParams : torch.cuda.FloatTensor
              Corresponds to specific rows of the `self.weight`
+
          signals : torch.cuda.FloatTensor
              Signals

@@ -294,90 +344,91 @@ class GaussianMixtureNoiseModel(nn.Module):
              )
          return value

-     def normalDens(
-         self, x: torch.Tensor, m_: torch.Tensor = 0.0, std_: torch.Tensor = None
-     ) -> torch.Tensor:
-         """Evaluates the normal probability density at `x` given the mean `m` and
-         standard deviation `std`.
+     def normalDens(self, x, m_=0.0, std_=None):
+         """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`.

          Parameters
          ----------
-         x: torch.Tensor
-             Observations (i.e., noisy image).
-         m_: torch.Tensor
-             Pixel-wise mean.
-         std_: torch.Tensor
-             Pixel-wise standard deviation.
+         x: torch.cuda.FloatTensor
+             Observations
+         m_: torch.cuda.FloatTensor
+             Mean
+         std_: torch.cuda.FloatTensor
+             Standard-deviation

          Returns
          -------
-         tmp: torch.Tensor
+         tmp: torch.cuda.FloatTensor
              Normal probability density of `x` given `m_` and `std_`
+
          """
          tmp = -((x - m_) ** 2)
          tmp = tmp / (2.0 * std_ * std_)
          tmp = torch.exp(tmp)
          tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
+         # print(tmp.min().item(), tmp.mean().item(), tmp.max().item(), tmp.shape)
          return tmp

-     def likelihood(
-         self, observations: torch.Tensor, signals: torch.Tensor
-     ) -> torch.Tensor:
-         """Evaluate the likelihood of observations given the signals and the
-         corresponding gaussian parameters.
+     def likelihood(self, observations, signals):
+         """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.

          Parameters
          ----------
          observations : torch.cuda.FloatTensor
-             Noisy observations.
+             Noisy observations
          signals : torch.cuda.FloatTensor
-             Underlying signals.
+             Underlying signals

          Returns
          -------
          value: p + self.tol
              Likelihood of observations given the signals and the GMM noise model
+
          """
-         self.to_device(signals)  # move all needed stuff to the same device as `signals`
+         if self.mode != "train":
+             signals = signals.cpu()
+             observations = observations.cpu()
+             self.weight = self.weight.to(signals.device)
+             self.min_signal = self.min_signal.to(signals.device)
+             self.max_signal = self.max_signal.to(signals.device)
+             self.tol = self.tol.to(signals.device)
+
          gaussianParameters = self.getGaussianParameters(signals)
          p = 0
          for gaussian in range(self.n_gaussian):
              p += (
                  self.normalDens(
-                     x=observations,
-                     m_=gaussianParameters[gaussian],
-                     std_=gaussianParameters[self.n_gaussian + gaussian],
+                     observations,
+                     gaussianParameters[gaussian],
+                     gaussianParameters[self.n_gaussian + gaussian],
                  )
                  * gaussianParameters[2 * self.n_gaussian + gaussian]
              )
          return p + self.tol

-     def getGaussianParameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
-         """Returns the noise model for given signals.
+     def getGaussianParameters(self, signals):
+         """Returns the noise model for given signals

          Parameters
          ----------
-         signals : torch.Tensor
+         signals : torch.cuda.FloatTensor
              Underlying signals

          Returns
          -------
-         gmmParams: list[torch.Tensor]
-             A list containing tensors representing `mu`, `sigma` and `alpha`
-             parameters for the `n_gaussian` gaussians in the mixture.
-
+         noiseModel: list of torch.cuda.FloatTensor
+             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
          """
-         gmmParams = []
+         noiseModel = []
          mu = []
          sigma = []
          alpha = []
          kernels = self.weight.shape[0] // 3
          for num in range(kernels):
-             # For each Gaussian in the mixture, evaluate mean, std and weight
              mu.append(self.polynomialRegressor(self.weight[num, :], signals))
-
+             # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W))
              expval = torch.exp(self.weight[kernels + num, :])
-             # TODO: why taking the exp? it is not in PPN2V paper...
+             # self.maxval = max(self.maxval, expval.max().item())
              sigmaTemp = self.polynomialRegressor(expval, signals)
              sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
              sigma.append(torch.sqrt(sigmaTemp))
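
In likelihood() above, the mixture density is assembled as p(x|s) = sum_k alpha_k(s) * N(x; mu_k(s), sigma_k(s)) + tol, with normalDens as the hand-rolled Gaussian pdf. A quick illustrative check of that density against torch.distributions (not part of the package):

    import torch

    x = torch.linspace(-3.0, 3.0, 7)
    m, std = torch.tensor(0.0), torch.tensor(1.0)
    manual = torch.exp(-((x - m) ** 2) / (2.0 * std * std)) / torch.sqrt(
        2.0 * torch.pi * std * std
    )
    reference = torch.distributions.Normal(m, std).log_prob(x).exp()
    assert torch.allclose(manual, reference, atol=1e-6)
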
@@ -386,7 +437,7 @@ class GaussianMixtureNoiseModel(nn.Module):
                  self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
                  + self.tol
              )
-             alpha.append(expval)  # NOTE: these are the numerators of weights
+             alpha.append(expval)

          sum_alpha = 0
          for al in range(kernels):
@@ -403,22 +454,21 @@ class GaussianMixtureNoiseModel(nn.Module):

          # subtracting the alpha weighted average of the means from the means
          # ensures that the GMM has the inclination to have the mean=signals.
-         # TODO: I don't understand why we need to learn the mean?
+         # it's like a residual connection. I don't understand why we need to learn the mean?
          for ker in range(kernels):
              mu[ker] = mu[ker] - sum_means + signals

          for i in range(kernels):
-             gmmParams.append(mu[i])
+             noiseModel.append(mu[i])
          for j in range(kernels):
-             gmmParams.append(sigma[j])
+             noiseModel.append(sigma[j])
          for k in range(kernels):
-             gmmParams.append(alpha[k])
+             noiseModel.append(alpha[k])

-         return gmmParams
+         return noiseModel

-     # TODO: this is to train the noise model
      def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
-         """Returns the Signal-Observation pixel intensities as a two-column array.
+         """Returns the Signal-Observation pixel intensities as a two-column array

          Parameters
          ----------
@@ -433,8 +483,9 @@ class GaussianMixtureNoiseModel(nn.Module):

          Returns
          -------
-         gmmParams: list of torch floats
+         noiseModel: list of torch floats
              Contains a list of `mu`, `sigma` and `alpha` for the `signals`
+
          """
          lb = np.percentile(signal, lowerClip)
          ub = np.percentile(signal, upperClip)
@@ -452,13 +503,7 @@ class GaussianMixtureNoiseModel(nn.Module):
          ]
          return fastShuffle(sig_obs_pairs, 2)

-     # TODO: what's the use of this method?
-     def forward(self, x, y):
-         """Temporary dummy forward method."""
-         return x, y
-
-     # TODO taken from pn2v. Ashesh needs to clarify this
-     def train_noise_model(
+     def fit(
          self,
          signal,
          observation,
@@ -499,9 +544,9 @@ class GaussianMixtureNoiseModel(nn.Module):
          )
          counter = 0
          optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
-         for t in range(n_epochs):
+         loss_arr = []

-             jointLoss = 0
+         for t in range(n_epochs):
              if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
                  counter = 0
                  sig_obs_pairs = fastShuffle(sig_obs_pairs, 1)
@@ -511,31 +556,53 @@ class GaussianMixtureNoiseModel(nn.Module):
              ]
              observations = batch_vectors[:, 1].astype(np.float32)
              signals = batch_vectors[:, 0].astype(np.float32)
-             # TODO do we absolutely need to move to GPU?
              observations = (
-                 torch.from_numpy(observations.astype(np.float32)).float().cuda()
+                 torch.from_numpy(observations.astype(np.float32))
+                 .float()
+                 .to(self.device)
              )
-             signals = torch.from_numpy(signals).float().cuda()
+             signals = torch.from_numpy(signals).float().to(self.device)
+
              p = self.likelihood(observations, signals)
-             loss = torch.mean(-torch.log(p))
-             jointLoss = jointLoss + loss

-             if t % 100 == 0:
-                 print(t, jointLoss.item())
+             jointLoss = torch.mean(-torch.log(p))
+             loss_arr.append(jointLoss.item())
+             if self.weight.isnan().any() or self.weight.isinf().any():
+                 print(
+                     "NaN or Inf detected in the weights. Aborting training at epoch: ",
+                     t,
+                 )
+                 break

-             if t % (int(n_epochs * 0.5)) == 0:
-                 trained_weight = self.weight.cpu().detach().numpy()
-                 min_signal = self.min_signal.cpu().detach().numpy()
-                 max_signal = self.max_signal.cpu().detach().numpy()
-                 # TODO do we need to save?
-                 # np.savez(self.path+name, trained_weight=trained_weight, min_signal = min_signal, max_signal = max_signal, min_sigma = self.min_sigma)
+             if t % 100 == 0:
+                 print(t, np.mean(loss_arr))

              optimizer.zero_grad()
              jointLoss.backward()
              optimizer.step()
              counter += 1

+         self.trained_weight = self.weight.cpu().detach().numpy()
+         self.min_signal = self.min_signal.cpu().detach().numpy()
+         self.max_signal = self.max_signal.cpu().detach().numpy()
          print("===================\n")
-         # print("The trained parameters (" + name + ") are saved at location: " + self.path)
-         # TODO return instead of save ?
-         return trained_weight, min_signal, max_signal, self.min_sigma
+
+     def save(self, path: str, name: str):
+         """Save the trained parameters of the noise model.
+
+         Parameters
+         ----------
+         path : str
+             Path to save the trained parameters.
+         name : str
+             File name to save the trained parameters.
+         """
+         os.makedirs(path, exist_ok=True)
+         np.savez(
+             os.path.join(path, name),
+             trained_weight=self.trained_weight,
+             min_signal=self.min_signal,
+             max_signal=self.max_signal,
+             min_sigma=self.min_sigma,
+         )
+         print("The trained parameters (" + name + ") are saved at location: " + path)
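
Net effect of the last hunks: train_noise_model() becomes fit(), which keeps the trained parameters on the instance instead of returning them, and the new save() persists them via np.savez. A hedged end-to-end sketch; the config fields, arrays and paths are placeholders, and fit() is assumed to keep its default hyperparameters.

    import numpy as np
    from careamics.config.nm_model import GaussianMixtureNMConfig
    from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel

    signal = np.load("clean_estimate.npy")           # hypothetical inputs
    observation = np.load("noisy_acquisitions.npy")

    cfg = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        min_signal=float(signal.min()),
        max_signal=float(signal.max()),
    )
    nm = GaussianMixtureNoiseModel(cfg)    # path=None -> "train" mode
    nm.fit(signal, observation)            # optimizes self.weight in place
    nm.save("noise_models", "gmm_ch0")     # np.savez appends .npz to the file name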