careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/lvae/noise_models.py
@@ -0,0 +1,541 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ if TYPE_CHECKING:
+     from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
+
+ # TODO this module shouldn't be in lvae folder
+
+
+ def noise_model_factory(
+     model_config: Optional[MultiChannelNMConfig],
+ ) -> Optional[MultiChannelNoiseModel]:
+     """Noise model factory.
+
+     Parameters
+     ----------
+     model_config : Optional[MultiChannelNMConfig]
+         Noise model configuration, a `MultiChannelNMConfig` config that defines
+         noise models for the different output channels.
+
+     Returns
+     -------
+     Optional[MultiChannelNoiseModel]
+         A noise model instance.
+
+     Raises
+     ------
+     NotImplementedError
+         If the chosen noise model `model_type` is not implemented.
+         Currently only `GaussianMixtureNoiseModel` is implemented.
+     """
+     if model_config:
+         noise_models = []
+         for nm_config in model_config.noise_models:
+             if nm_config.path:
+                 if nm_config.model_type == "GaussianMixtureNoiseModel":
+                     noise_models.append(GaussianMixtureNoiseModel(nm_config))
+                 else:
+                     raise NotImplementedError(
+                         f"Model {nm_config.model_type} is not implemented"
+                     )
+
+             else:  # TODO this means signal/obs are provided. Controlled in pydantic model
+                 # TODO train a new model. Config should always be provided?
+                 if nm_config.model_type == "GaussianMixtureNoiseModel":
+                     trained_nm = train_gm_noise_model(nm_config)
+                     noise_models.append(trained_nm)
+                 else:
+                     raise NotImplementedError(
+                         f"Model {nm_config.model_type} is not implemented"
+                     )
+         return MultiChannelNoiseModel(noise_models)
+     return None
+
+
+ def train_gm_noise_model(
+     model_config: GaussianMixtureNMConfig,
+ ) -> GaussianMixtureNoiseModel:
+     """Train a Gaussian mixture noise model.
+
+     Parameters
+     ----------
+     model_config : GaussianMixtureNMConfig
+         Configuration of the Gaussian mixture noise model to train.
+
+     Returns
+     -------
+     GaussianMixtureNoiseModel
+         The trained noise model.
+     """
+     # TODO where to put train params?
+     # TODO any training params ? Different channels ?
+     noise_model = GaussianMixtureNoiseModel(model_config)
+     # TODO revisit config unpacking
+     noise_model.train_noise_model(noise_model.signal, noise_model.observation)
+     return noise_model
+
+
+ class MultiChannelNoiseModel(nn.Module):
+     def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):
+         """Constructor.
+
+         To handle noise models and the relative likelihood computation for multiple
+         output channels (e.g., muSplit, denoiSplit).
+
+         This class:
+         - receives as input a variable number of noise models, one for each channel.
+         - computes the likelihood of observations given signals for each channel.
+         - returns the concatenation of these likelihoods.
+
+         Parameters
+         ----------
+         nmodels : list[GaussianMixtureNoiseModel]
+             List of noise models, one for each output channel.
+         """
+         super().__init__()
+         for i, nmodel in enumerate(nmodels):
+             if nmodel is not None:
+                 self.add_module(
+                     f"nmodel_{i}", nmodel
+                 )  # TODO: wouldn't it be easier to use a list?
+
+         self._nm_cnt = 0
+         for nmodel in nmodels:
+             if nmodel is not None:
+                 self._nm_cnt += 1
+
+         print(f"[{self.__class__.__name__}] Nmodels count: {self._nm_cnt}")
+
+     def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
+         """Compute the likelihood of observations given signals for each channel.
+
+         Parameters
+         ----------
+         obs : torch.Tensor
+             Noisy observations, i.e., the target(s). Specifically, the input noisy
+             image for HDN, or the noisy unmixed images used for supervision
+             for denoiSplit. Shape: (B, C, [Z], Y, X), where C is the number of
+             unmixed channels.
+         signal : torch.Tensor
+             Underlying signals, i.e., the (clean) output of the model. Specifically,
+             the denoised image for HDN, or the unmixed images for denoiSplit.
+             Shape: (B, C, [Z], Y, X), where C is the number of unmixed channels.
+
+         Returns
+         -------
+         torch.Tensor
+             Concatenation of the per-channel likelihoods, of shape
+             (B, C, [Z], Y, X).
+         """
+         # Case 1: obs and signal have a single channel (e.g., denoising)
+         if obs.shape[1] == 1:
+             assert signal.shape[1] == 1
+             return self.nmodel_0.likelihood(obs, signal)
+
+         # Case 2: obs and signal have multiple channels (e.g., denoiSplit)
+         assert obs.shape[1] == self._nm_cnt, (
+             "The number of channels in `obs` must match the number of noise models."
+             f" Got instead: obs={obs.shape[1]}, nm={self._nm_cnt}"
+         )
+         ll_list = []
+         for ch_idx in range(obs.shape[1]):
+             nmodel = getattr(self, f"nmodel_{ch_idx}")
+             ll_list.append(
+                 nmodel.likelihood(
+                     obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1]
+                 )  # slicing to keep the channel dimension
+             )
+         return torch.cat(ll_list, dim=1)
+
+
+ # TODO: is this needed?
+ def fastShuffle(series, num):
+     """Shuffle the rows of a 2D array `num` times.
+
+     Parameters
+     ----------
+     series : np.ndarray
+         A 2D array whose rows are to be shuffled.
+     num : int
+         Number of shuffling passes to apply.
+
+     Returns
+     -------
+     np.ndarray
+         The array with shuffled rows.
+     """
+     length = series.shape[0]
+     for _ in range(num):
+         series = series[np.random.permutation(length), :]
+     return series
+
+
+ class GaussianMixtureNoiseModel(nn.Module):
+     """Define a noise model parameterized as a mixture of gaussians.
+
+     If `config.path` is not provided a new object is initialized from scratch.
+     Otherwise, a model is loaded from `config.path`.
+
+     Parameters
+     ----------
+     config : GaussianMixtureNMConfig
+         A `pydantic` model that defines the configuration of the GMM noise model.
+
+     Attributes
+     ----------
+     min_signal : float
+         Minimum signal intensity expected in the image.
+     max_signal : float
+         Maximum signal intensity expected in the image.
+     path : Union[str, Path]
+         Path to the directory where the trained noise model (*.npz) is saved
+         in the `train` method.
+     weight : torch.nn.Parameter
+         A [3*n_gaussian, n_coeff] sized array containing the values of the weights
+         describing the GMM noise model, with each row corresponding to one
+         parameter of each gaussian, namely [mean, standard deviation and weight].
+         Specifically, rows are organized as follows:
+         - first n_gaussian rows correspond to the means
+         - next n_gaussian rows correspond to the standard deviations
+         - last n_gaussian rows correspond to the weights
+         If `weight=None`, the weight array is initialized using the `min_signal`
+         and `max_signal` parameters.
+     n_gaussian : int
+         Number of gaussians in the mixture.
+     n_coeff : int
+         Number of coefficients to describe the functional relationship between
+         gaussian parameters and the signal. 2 implies a linear relationship,
+         3 implies a quadratic relationship and so on.
+     device : device
+         GPU device.
+     min_sigma : float
+         All values of `standard deviation` below this are clamped to this value.
+     """
+
+     # TODO training a NM relies on getting clean data (e.g., from N2V)
+     def __init__(self, config: GaussianMixtureNMConfig):
+         super().__init__()
+         self._learnable = False
+
+         if config.path is None:
+             # TODO this is (probably) to train a nm. We leave it for later refactoring
+             weight = config.weight
+             n_gaussian = config.n_gaussian
+             n_coeff = config.n_coeff
+             min_signal = config.min_signal
+             max_signal = config.max_signal
+             # self.device = kwargs.get('device')
+             # TODO min_sigma can't be None?
+             self.min_sigma = config.min_sigma
+             if weight is None:
+                 weight = np.random.randn(n_gaussian * 3, n_coeff)
+                 weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
+                 weight = torch.from_numpy(weight.astype(np.float32)).float().cuda()
+                 weight.requires_grad = True
+
+             self.n_gaussian = weight.shape[0] // 3
+             self.n_coeff = weight.shape[1]
+             self.weight = weight
+             self.min_signal = torch.Tensor([min_signal])
+             self.max_signal = torch.Tensor([max_signal])
+             self.tol = torch.Tensor([1e-10])
+         else:
+             params = np.load(config.path)
+             # self.device = kwargs.get('device')
+
+             self.min_signal = torch.Tensor(params["min_signal"])
+             self.max_signal = torch.Tensor(params["max_signal"])
+
+             self.weight = torch.nn.Parameter(
+                 torch.Tensor(params["trained_weight"]), requires_grad=False
+             )
+             self.min_sigma = params["min_sigma"].item()
+             self.n_gaussian = self.weight.shape[0] // 3
+             self.n_coeff = self.weight.shape[1]
+             self.tol = torch.Tensor([1e-10])  # .to(self.device)
+             self.min_signal = torch.Tensor([self.min_signal])  # .to(self.device)
+             self.max_signal = torch.Tensor([self.max_signal])  # .to(self.device)
+
+         print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")
+
+     def make_learnable(self):
+         print(f"[{self.__class__.__name__}] Making noise model learnable")
+         self._learnable = True
+         self.weight.requires_grad = True
+
+     def to_device(self, cuda_tensor):
+         # TODO wtf is this ?
+         # move everything to GPU
+         if self.min_signal.device != cuda_tensor.device:
+             self.max_signal = self.max_signal.cuda()
+             self.min_signal = self.min_signal.cuda()
+             self.tol = self.tol.cuda()
+             # self.weight = self.weight.cuda()
+         if self._learnable:
+             self.weight.requires_grad = True
+
+     def polynomialRegressor(self, weightParams, signals):
+         """Combine `weightParams` and `signals` to regress the gaussian parameter values.
+
+         Parameters
+         ----------
+         weightParams : torch.cuda.FloatTensor
+             Corresponds to specific rows of `self.weight`.
+         signals : torch.cuda.FloatTensor
+             Signals.
+
+         Returns
+         -------
+         value : torch.cuda.FloatTensor
+             Corresponds to either the mean, standard deviation or weight,
+             evaluated at `signals`.
+         """
+         value = 0
+         for i in range(weightParams.shape[0]):
+             value += weightParams[i] * (
+                 ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
+             )
+         return value
+
+     def normalDens(
+         self, x: torch.Tensor, m_: torch.Tensor = 0.0, std_: torch.Tensor = None
+     ) -> torch.Tensor:
+         """Evaluate the normal probability density at `x` given the mean `m_` and
+         standard deviation `std_`.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Observations (i.e., noisy image).
+         m_ : torch.Tensor
+             Pixel-wise mean.
+         std_ : torch.Tensor
+             Pixel-wise standard deviation.
+
+         Returns
+         -------
+         tmp : torch.Tensor
+             Normal probability density of `x` given `m_` and `std_`.
+         """
+         tmp = -((x - m_) ** 2)
+         tmp = tmp / (2.0 * std_ * std_)
+         tmp = torch.exp(tmp)
+         tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
+         return tmp
+
+     def likelihood(
+         self, observations: torch.Tensor, signals: torch.Tensor
+     ) -> torch.Tensor:
+         """Evaluate the likelihood of observations given the signals and the
+         corresponding gaussian parameters.
+
+         Parameters
+         ----------
+         observations : torch.cuda.FloatTensor
+             Noisy observations.
+         signals : torch.cuda.FloatTensor
+             Underlying signals.
+
+         Returns
+         -------
+         value : torch.cuda.FloatTensor
+             Likelihood of observations given the signals and the GMM noise model.
+         """
+         self.to_device(signals)  # move all needed tensors to the same device as `signals`
+         gaussianParameters = self.getGaussianParameters(signals)
+         p = 0
+         for gaussian in range(self.n_gaussian):
+             p += (
+                 self.normalDens(
+                     x=observations,
+                     m_=gaussianParameters[gaussian],
+                     std_=gaussianParameters[self.n_gaussian + gaussian],
+                 )
+                 * gaussianParameters[2 * self.n_gaussian + gaussian]
+             )
+         return p + self.tol
+
+     def getGaussianParameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
+         """Return the gaussian parameters of the noise model for the given signals.
+
+         Parameters
+         ----------
+         signals : torch.Tensor
+             Underlying signals.
+
+         Returns
+         -------
+         gmmParams : list[torch.Tensor]
+             A list containing tensors representing the `mu`, `sigma` and `alpha`
+             parameters for the `n_gaussian` gaussians in the mixture.
+         """
+         gmmParams = []
+         mu = []
+         sigma = []
+         alpha = []
+         kernels = self.weight.shape[0] // 3
+         for num in range(kernels):
+             # For each Gaussian in the mixture, evaluate mean, std and weight
+             mu.append(self.polynomialRegressor(self.weight[num, :], signals))
+
+             expval = torch.exp(self.weight[kernels + num, :])
+             # TODO: why taking the exp? it is not in PPN2V paper...
+             sigmaTemp = self.polynomialRegressor(expval, signals)
+             sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
+             sigma.append(torch.sqrt(sigmaTemp))
+
+             expval = torch.exp(
+                 self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
+                 + self.tol
+             )
+             alpha.append(expval)  # NOTE: these are the numerators of the weights
+
+         sum_alpha = 0
+         for al in range(kernels):
+             sum_alpha = alpha[al] + sum_alpha
+
+         # the sum of the alphas is forced to be 1
+         for ker in range(kernels):
+             alpha[ker] = alpha[ker] / sum_alpha
+
+         sum_means = 0
+         # sum_means is the alpha-weighted average of the means
+         for ker in range(kernels):
+             sum_means = alpha[ker] * mu[ker] + sum_means
+
+         # subtracting the alpha-weighted average of the means from the means
+         # ensures that the GMM has the inclination to have mean=signals.
+         # TODO: I don't understand why we need to learn the mean?
+         for ker in range(kernels):
+             mu[ker] = mu[ker] - sum_means + signals
+
+         for i in range(kernels):
+             gmmParams.append(mu[i])
+         for j in range(kernels):
+             gmmParams.append(sigma[j])
+         for k in range(kernels):
+             gmmParams.append(alpha[k])
+
+         return gmmParams
+
+     # TODO: this is to train the noise model
+     def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
+         """Return the signal-observation pixel intensities as a two-column array.
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean signal data.
+         observation : numpy array
+             Noisy observation data.
+         lowerClip : float
+             Lower percentile bound for clipping.
+         upperClip : float
+             Upper percentile bound for clipping.
+
+         Returns
+         -------
+         numpy array
+             A two-column array of (signal, observation) pixel-intensity pairs,
+             clipped to the given percentiles and shuffled.
+         """
+         lb = np.percentile(signal, lowerClip)
+         ub = np.percentile(signal, upperClip)
+         stepsize = observation[0].size
+         n_observations = observation.shape[0]
+         n_signals = signal.shape[0]
+         sig_obs_pairs = np.zeros((n_observations * stepsize, 2))
+
+         for i in range(n_observations):
+             j = i // (n_observations // n_signals)
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel()
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel()
+         sig_obs_pairs = sig_obs_pairs[
+             (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
+         ]
+         return fastShuffle(sig_obs_pairs, 2)
+
+     # TODO: what's the use of this method?
+     def forward(self, x, y):
+         """Temporary dummy forward method."""
+         return x, y
+
+     # TODO taken from pn2v. Ashesh needs to clarify this
+     def train_noise_model(
+         self,
+         signal,
+         observation,
+         learning_rate=1e-1,
+         batchSize=250000,
+         n_epochs=2000,
+         name="GMMNoiseModel.npz",
+         lowerClip=0,
+         upperClip=100,
+     ):
+         """Train the noise model from signal-observation pairs.
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean signal data.
+         observation : numpy array
+             Noisy observation data.
+         learning_rate : float
+             Learning rate. Default = 1e-1.
+         batchSize : int
+             Mini-batch size. Default = 250000.
+         n_epochs : int
+             Number of epochs. Default = 2000.
+         name : string
+             Model name. Default is `GMMNoiseModel.npz`. After training, the
+             model is saved at the location `path`.
+         lowerClip : int
+             Lower percentile for clipping. Default is 0.
+         upperClip : int
+             Upper percentile for clipping. Default is 100.
+         """
+         sig_obs_pairs = self.getSignalObservationPairs(
+             signal, observation, lowerClip, upperClip
+         )
+         counter = 0
+         optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
+         for t in range(n_epochs):
+
+             jointLoss = 0
+             if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
+                 counter = 0
+                 sig_obs_pairs = fastShuffle(sig_obs_pairs, 1)
+
+             batch_vectors = sig_obs_pairs[
+                 counter * batchSize : (counter + 1) * batchSize, :
+             ]
+             observations = batch_vectors[:, 1].astype(np.float32)
+             signals = batch_vectors[:, 0].astype(np.float32)
+             # TODO do we absolutely need to move to GPU?
+             observations = (
+                 torch.from_numpy(observations.astype(np.float32)).float().cuda()
+             )
+             signals = torch.from_numpy(signals).float().cuda()
+             p = self.likelihood(observations, signals)
+             loss = torch.mean(-torch.log(p))
+             jointLoss = jointLoss + loss
+
+             if t % 100 == 0:
+                 print(t, jointLoss.item())
+
+             if t % (int(n_epochs * 0.5)) == 0:
+                 trained_weight = self.weight.cpu().detach().numpy()
+                 min_signal = self.min_signal.cpu().detach().numpy()
+                 max_signal = self.max_signal.cpu().detach().numpy()
+                 # TODO do we need to save?
+                 # np.savez(self.path+name, trained_weight=trained_weight, min_signal=min_signal, max_signal=max_signal, min_sigma=self.min_sigma)
+
+             optimizer.zero_grad()
+             jointLoss.backward()
+             optimizer.step()
+             counter += 1
+
+         print("===================\n")
+         # print("The trained parameters (" + name + ") is saved at location: " + self.path)
+         # TODO return instead of save?
+         return trained_weight, min_signal, max_signal, self.min_sigma
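
For orientation: the module above parameterizes the noise as a signal-dependent gaussian mixture, so for a signal s the likelihood of an observation o is p(o | s) = sum_k alpha_k(s) * N(o; mu_k(s), sigma_k(s)^2), where mu_k, sigma_k and alpha_k are each regressed from s by a polynomial whose coefficients are stored in `weight`. The following is a minimal usage sketch, not part of the release: it assumes pre-trained `.npz` noise-model files (the file names are placeholders) and that the pydantic configs accept the field names the module itself accesses (`noise_models`, `model_type`, `path`); the exact `GaussianMixtureNMConfig`/`MultiChannelNMConfig` signatures may differ.

import torch

from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
from careamics.models.lvae.noise_models import noise_model_factory

# One pre-trained GMM noise model (*.npz) per output channel; the file
# names below are hypothetical placeholders.
nm_configs = [
    GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        path=f"noise_model_ch{i}.npz",
    )
    for i in range(2)
]
multi_nm_config = MultiChannelNMConfig(noise_models=nm_configs)

# The factory wraps the per-channel models in a MultiChannelNoiseModel.
noise_model = noise_model_factory(multi_nm_config)

# Likelihood of noisy observations given predicted signals; the channel
# count must match the number of noise models.
obs = torch.rand(1, 2, 64, 64)     # (B, C, Y, X)
signal = torch.rand(1, 2, 64, 64)  # (B, C, Y, X)
likelihood = noise_model.likelihood(obs, signal)  # same shape as inputs

# Negative log-likelihood, the quantity minimized by `train_noise_model`.
nll = -torch.log(likelihood).mean()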