careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's advisory page for this release for more details.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,26 +1,108 @@
1
- import json
2
- import os
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Optional
3
4
 
4
5
  import numpy as np
5
6
  import torch
6
7
  import torch.nn as nn
7
8
 
8
- from .utils import ModelType
9
+ if TYPE_CHECKING:
10
+ from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
9
11
 
12
+ # TODO this module shouldn't be in lvae folder
10
13
 
11
- class DisentNoiseModel(nn.Module):
12
14
 
13
- def __init__(self, *nmodels):
14
- """
15
- Constructor.
15
+ def noise_model_factory(
16
+ model_config: Optional[MultiChannelNMConfig],
17
+ ) -> Optional[MultiChannelNoiseModel]:
18
+ """Noise model factory.
19
+
20
+ Parameters
21
+ ----------
22
+ model_config : Optional[MultiChannelNMConfig]
23
+ Noise model configuration, a `MultiChannelNMConfig` config that defines
24
+ noise models for the different output channels.
25
+
26
+ Returns
27
+ -------
28
+ Optional[MultiChannelNoiseModel]
29
+ A noise model instance.
30
+
31
+ Raises
32
+ ------
33
+ NotImplementedError
34
+ If the chosen noise model `model_type` is not implemented.
35
+ Currently only `GaussianMixtureNoiseModel` is implemented.
36
+ """
37
+ if model_config:
38
+ noise_models = []
39
+ for nm_config in model_config.noise_models:
40
+ if nm_config.path:
41
+ if nm_config.model_type == "GaussianMixtureNoiseModel":
42
+ noise_models.append(GaussianMixtureNoiseModel(nm_config))
43
+ else:
44
+ raise NotImplementedError(
45
+ f"Model {nm_config.model_type} is not implemented"
46
+ )
47
+
48
+ else: # TODO this means signal/obs are provided. Controlled in pydantic model
49
+ # TODO train a new model. Config should always be provided?
50
+ if nm_config.model_type == "GaussianMixtureNoiseModel":
51
+ trained_nm = train_gm_noise_model(nm_config)
52
+ noise_models.append(trained_nm)
53
+ else:
54
+ raise NotImplementedError(
55
+ f"Model {nm_config.model_type} is not implemented"
56
+ )
57
+ return MultiChannelNoiseModel(noise_models)
58
+ return None
59
+
60
+
61
+ def train_gm_noise_model(
62
+ model_config: GaussianMixtureNMConfig,
63
+ ) -> GaussianMixtureNoiseModel:
64
+ """Train a Gaussian mixture noise model.
65
+
66
+ Parameters
67
+ ----------
68
+ model_config : GaussianMixtureNoiseModel
69
+ _description_
70
+
71
+ Returns
72
+ -------
73
+ _description_
74
+ """
75
+ # TODO where to put train params?
76
+ # TODO any training params ? Different channels ?
77
+ noise_model = GaussianMixtureNoiseModel(model_config)
78
+ # TODO revisit config unpacking
79
+ noise_model.train_noise_model(model_config.signal, model_config.observation)
80
+ return noise_model
81
+
16
82
 
17
- This class receives as input a variable number of noise models, each one corresponding to a channel.
83
+ class MultiChannelNoiseModel(nn.Module):
84
+ def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):
85
+ """Constructor.
86
+
87
+ To handle noise models and the relative likelihood computation for multiple
88
+ output channels (e.g., muSplit, denoiseSplit).
89
+
90
+ This class:
91
+ - receives as input a variable number of noise models, one for each channel.
92
+ - computes the likelihood of observations given signals for each channel.
93
+ - returns the concatenation of these likelihoods.
94
+
95
+ Parameters
96
+ ----------
97
+ nmodels : list[GaussianMixtureNoiseModel]
98
+ List of noise models, one for each output channel.
18
99
  """
19
100
  super().__init__()
20
- # self.nmodels = nmodels
21
101
  for i, nmodel in enumerate(nmodels):
22
102
  if nmodel is not None:
23
- self.add_module(f"nmodel_{i}", nmodel)
103
+ self.add_module(
104
+ f"nmodel_{i}", nmodel
105
+ ) # TODO: wouldn't be easier to use a list?
24
106
 
25
107
  self._nm_cnt = 0
26
108
  for nmodel in nmodels:
@@ -30,181 +112,141 @@ class DisentNoiseModel(nn.Module):
30
112
  print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")
31
113
 
32
114
  def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
115
+ """Compute the likelihood of observations given signals for each channel.
33
116
 
117
+ Parameters
118
+ ----------
119
+ obs : torch.Tensor
120
+ Noisy observations, i.e., the target(s). Specifically, the input noisy
121
+ image for HDN, or the noisy unmixed images used for supervision
122
+ for denoiSplit. Shape: (B, C, [Z], Y, X), where C is the number of
123
+ unmixed channels.
124
+ signal : torch.Tensor
125
+ Underlying signals, i.e., the (clean) output of the model. Specifically, the
126
+ denoised image for HDN, or the unmixed images for denoiSplit.
127
+ Shape: (B, C, [Z], Y, X), where C is the number of unmixed channels.
128
+ """
129
+ # Case 1: obs and signal have a single channel (e.g., denoising)
34
130
  if obs.shape[1] == 1:
35
131
  assert signal.shape[1] == 1
36
- assert self.n2model is None
37
132
  return self.nmodel_0.likelihood(obs, signal)
38
133
 
39
- assert obs.shape[1] == self._nm_cnt, f"{obs.shape[1]} != {self._nm_cnt}"
40
-
134
+ # Case 2: obs and signal have multiple channels (e.g., denoiSplit)
135
+ assert obs.shape[1] == self._nm_cnt, (
136
+ "The number of channels in `obs` must match the number of noise models."
137
+ f" Got instead: obs={obs.shape[1]}, nm={self._nm_cnt}"
138
+ )
41
139
  ll_list = []
42
140
  for ch_idx in range(obs.shape[1]):
43
141
  nmodel = getattr(self, f"nmodel_{ch_idx}")
44
142
  ll_list.append(
45
143
  nmodel.likelihood(
46
144
  obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1]
47
- )
145
+ ) # slicing to keep the channel dimension
48
146
  )
49
-
50
147
  return torch.cat(ll_list, dim=1)
51
148
 
52
149
 
53
- def last2path(fpath: str):
54
- return os.path.join(*fpath.split("/")[-2:])
55
-
56
-
57
- def get_nm_config(noise_model_fpath: str):
58
- config_fpath = os.path.join(os.path.dirname(noise_model_fpath), "config.json")
59
- with open(config_fpath) as f:
60
- noise_model_config = json.load(f)
61
- return noise_model_config
62
-
63
-
150
+ # TODO: is this needed?
64
151
  def fastShuffle(series, num):
152
+ """_summary_.
153
+
154
+ Parameters
155
+ ----------
156
+ series : _type_
157
+ _description_
158
+ num : _type_
159
+ _description_
160
+
161
+ Returns
162
+ -------
163
+ _type_
164
+ _description_
165
+ """
65
166
  length = series.shape[0]
66
- for i in range(num):
167
+ for _ in range(num):
67
168
  series = series[np.random.permutation(length), :]
68
169
  return series
69
170
 
70
171
 
71
- def get_noise_model(
72
- enable_noise_model: bool,
73
- model_type: ModelType,
74
- noise_model_type: str,
75
- noise_model_ch1_fpath: str,
76
- noise_model_ch2_fpath: str,
77
- noise_model_learnable: bool = False,
78
- denoise_channel: str = "input",
79
- ):
80
- if enable_noise_model:
81
- nmodels = []
82
- # HDN -> one single output -> one single noise model
83
- if model_type == ModelType.Denoiser:
84
- if noise_model_type == "hist":
85
- raise NotImplementedError(
86
- '"hist" noise model is not supported for now.'
87
- )
88
- elif noise_model_type == "gmm":
89
- if denoise_channel == "Ch1":
90
- nmodel_fpath = noise_model_ch1_fpath
91
- print(f"Noise model Ch1: {nmodel_fpath}")
92
- nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
93
- nmodel2 = None
94
- nmodels = [nmodel1, nmodel2]
95
- elif denoise_channel == "Ch2":
96
- nmodel_fpath = noise_model_ch2_fpath
97
- print(f"Noise model Ch2: {nmodel_fpath}")
98
- nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
99
- nmodel2 = None
100
- nmodels = [nmodel1, nmodel2]
101
- elif denoise_channel == "input":
102
- nmodel_fpath = noise_model_ch1_fpath
103
- print(f"Noise model input: {nmodel_fpath}")
104
- nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
105
- nmodel2 = None
106
- nmodels = [nmodel1, nmodel2]
107
- else:
108
- raise ValueError(f"Invalid denoise_channel: {denoise_channel}")
109
- # muSplit -> two outputs -> two noise models
110
- elif noise_model_type == "gmm":
111
- print(f"Noise model Ch1: {noise_model_ch1_fpath}")
112
- print(f"Noise model Ch2: {noise_model_ch2_fpath}")
113
-
114
- nmodel1 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch1_fpath))
115
- nmodel2 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch2_fpath))
116
-
117
- nmodels = [nmodel1, nmodel2]
118
-
119
- # if 'noise_model_ch3_fpath' in config.model:
120
- # print(f'Noise model Ch3: {config.model.noise_model_ch3_fpath}')
121
- # nmodel3 = GaussianMixtureNoiseModel(params=np.load(config.model.noise_model_ch3_fpath))
122
- # nmodels = [nmodel1, nmodel2, nmodel3]
123
- # else:
124
- # nmodels = [nmodel1, nmodel2]
125
- else:
126
- raise ValueError(f"Invalid noise_model_type: {noise_model_type}")
127
-
128
- if noise_model_learnable:
129
- for nmodel in nmodels:
130
- if nmodel is not None:
131
- nmodel.make_learnable()
132
-
133
- return DisentNoiseModel(*nmodels)
134
- return None
135
-
136
-
137
172
  class GaussianMixtureNoiseModel(nn.Module):
138
- """
139
- The GaussianMixtureNoiseModel class describes a noise model which is parameterized as a mixture of gaussians.
140
- If you would like to initialize a new object from scratch, then set `params`= None and specify the other parameters as keyword arguments.
141
- If you are instead loading a model, use only `params`.
173
+ """Define a noise model parameterized as a mixture of gaussians.
174
+
175
+ If `config.path` is not provided a new object is initialized from scratch.
176
+ Otherwise, a model is loaded from `config.path`.
142
177
 
143
178
  Parameters
144
179
  ----------
145
- **kwargs: keyworded, variable-length argument dictionary.
146
- Arguments include:
147
- min_signal : float
148
- Minimum signal intensity expected in the image.
149
- max_signal : float
150
- Maximum signal intensity expected in the image.
151
- path: string
152
- Path to the directory where the trained noise model (*.npz) is saved in the `train` method.
153
- weight : array
154
- A [3*n_gaussian, n_coeff] sized array containing the values of the weights describing the noise model.
155
- Each gaussian contributes three parameters (mean, standard deviation and weight), hence the number of rows in `weight` are 3*n_gaussian.
156
- If `weight=None`, the weight array is initialized using the `min_signal` and `max_signal` parameters.
157
- n_gaussian: int
158
- Number of gaussians.
159
- n_coeff: int
160
- Number of coefficients to describe the functional relationship between gaussian parameters and the signal.
161
- 2 implies a linear relationship, 3 implies a quadratic relationship and so on.
162
- device: device
163
- GPU device.
164
- min_sigma: int
165
- All values of sigma (`standard deviation`) below min_sigma are clamped to become equal to min_sigma.
166
- params: dictionary
167
- Use `params` if one wishes to load a model with trained weights.
168
- While initializing a new object of the class `GaussianMixtureNoiseModel` from scratch, set this to `None`.
180
+ config : GaussianMixtureNMConfig
181
+ A `pydantic` model that defines the configuration of the GMM noise model.
182
+
183
+ Attributes
184
+ ----------
185
+ min_signal : float
186
+ Minimum signal intensity expected in the image.
187
+ max_signal : float
188
+ Maximum signal intensity expected in the image.
189
+ path: Union[str, Path]
190
+ Path to the directory where the trained noise model (*.npz) is saved in the `train` method.
191
+ weight : torch.nn.Parameter
192
+ A [3*n_gaussian, n_coeff] sized array containing the values of the weights
193
+ describing the GMM noise model, with each row corresponding to one
194
+ parameter of each gaussian, namely [mean, standard deviation and weight].
195
+ Specifically, rows are organized as follows:
196
+ - first n_gaussian rows correspond to the means
197
+ - next n_gaussian rows correspond to the weights
198
+ - last n_gaussian rows correspond to the standard deviations
199
+ If `weight=None`, the weight array is initialized using the `min_signal`
200
+ and `max_signal` parameters.
201
+ n_gaussian: int
202
+ Number of gaussians in the mixture.
203
+ n_coeff: int
204
+ Number of coefficients to describe the functional relationship between gaussian
205
+ parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
206
+ relationship and so on.
207
+ device: device
208
+ GPU device.
209
+ min_sigma: float
210
+ All values of `standard deviation` below this are clamped to this value.
169
211
  """
170
212
 
171
- def __init__(self, **kwargs):
213
+ # TODO training a NM relies on getting a clean data(N2V e.g,)
214
+ def __init__(self, config: GaussianMixtureNMConfig):
172
215
  super().__init__()
173
216
  self._learnable = False
174
217
 
175
- if kwargs.get("params") is None:
176
- weight = kwargs.get("weight")
177
- n_gaussian = kwargs.get("n_gaussian")
178
- n_coeff = kwargs.get("n_coeff")
179
- min_signal = kwargs.get("min_signal")
180
- max_signal = kwargs.get("max_signal")
218
+ if config.path is None:
219
+ # TODO this is (probably) to train a nm. We leave it for later refactoring
220
+ weight = config.weight
221
+ n_gaussian = config.n_gaussian
222
+ n_coeff = config.n_coeff
223
+ min_signal = config.min_signal
224
+ max_signal = config.max_signal
181
225
  # self.device = kwargs.get('device')
182
- self.path = kwargs.get("path")
183
- self.min_sigma = kwargs.get("min_sigma")
226
+ # TODO min_sigma cant be None ?
227
+ self.min_sigma = config.min_sigma
184
228
  if weight is None:
185
229
  weight = np.random.randn(n_gaussian * 3, n_coeff)
186
230
  weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
187
- weight = torch.from_numpy(
188
- weight.astype(np.float32)
189
- ).float() # .to(self.device)
190
- weight = nn.Parameter(weight, requires_grad=True)
231
+ weight = torch.from_numpy(weight.astype(np.float32)).float().cuda()
232
+ weight.requires_grad = True
191
233
 
192
234
  self.n_gaussian = weight.shape[0] // 3
193
235
  self.n_coeff = weight.shape[1]
194
236
  self.weight = weight
195
- self.min_signal = torch.Tensor([min_signal]) # .to(self.device)
196
- self.max_signal = torch.Tensor([max_signal]) # .to(self.device)
197
- self.tol = torch.Tensor([1e-10]) # .to(self.device)
237
+ self.min_signal = torch.Tensor([min_signal])
238
+ self.max_signal = torch.Tensor([max_signal])
239
+ self.tol = torch.Tensor([1e-10])
198
240
  else:
199
- params = kwargs.get("params")
241
+ params = np.load(config.path)
200
242
  # self.device = kwargs.get('device')
201
243
 
202
- self.min_signal = torch.Tensor(params["min_signal"]) # .to(self.device)
203
- self.max_signal = torch.Tensor(params["max_signal"]) # .to(self.device)
244
+ self.min_signal = torch.Tensor(params["min_signal"])
245
+ self.max_signal = torch.Tensor(params["max_signal"])
204
246
 
205
247
  self.weight = torch.nn.Parameter(
206
248
  torch.Tensor(params["trained_weight"]), requires_grad=False
207
- ) # .to(self.device)
249
+ )
208
250
  self.min_sigma = params["min_sigma"].item()
209
251
  self.n_gaussian = self.weight.shape[0] // 3
210
252
  self.n_coeff = self.weight.shape[1]
@@ -216,19 +258,17 @@ class GaussianMixtureNoiseModel(nn.Module):
216
258
 
217
259
  def make_learnable(self):
218
260
  print(f"[{self.__class__.__name__}] Making noise model learnable")
219
-
220
261
  self._learnable = True
221
262
  self.weight.requires_grad = True
222
263
 
223
- #
224
-
225
264
  def to_device(self, cuda_tensor):
265
+ # TODO wtf is this ?
226
266
  # move everything to GPU
227
267
  if self.min_signal.device != cuda_tensor.device:
228
- self.max_signal = self.max_signal.to(cuda_tensor.device)
229
- self.min_signal = self.min_signal.to(cuda_tensor.device)
230
- self.tol = self.tol.to(cuda_tensor.device)
231
- self.weight = self.weight.to(cuda_tensor.device)
268
+ self.max_signal = self.max_signal.cuda()
269
+ self.min_signal = self.min_signal.cuda()
270
+ self.tol = self.tol.cuda()
271
+ # self.weight = self.weight.cuda()
232
272
  if self._learnable:
233
273
  self.weight.requires_grad = True
234
274
 
@@ -254,21 +294,24 @@ class GaussianMixtureNoiseModel(nn.Module):
254
294
  )
255
295
  return value
256
296
 
257
- def normalDens(self, x, m_=0.0, std_=None):
258
- """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`.
297
+ def normalDens(
298
+ self, x: torch.Tensor, m_: torch.Tensor = 0.0, std_: torch.Tensor = None
299
+ ) -> torch.Tensor:
300
+ """Evaluates the normal probability density at `x` given the mean `m` and
301
+ standard deviation `std`.
259
302
 
260
303
  Parameters
261
304
  ----------
262
- x: torch.cuda.FloatTensor
263
- Observations
264
- m_: torch.cuda.FloatTensor
265
- Mean
266
- std_: torch.cuda.FloatTensor
267
- Standard-deviation
305
+ x: torch.Tensor
306
+ Observations (i.e., noisy image).
307
+ m_: torch.Tensor
308
+ Pixel-wise mean.
309
+ std_: torch.Tensor
310
+ Pixel-wise standard deviation.
268
311
 
269
312
  Returns
270
313
  -------
271
- tmp: torch.cuda.FloatTensor
314
+ tmp: torch.Tensor
272
315
  Normal probability density of `x` given `m_` and `std_`
273
316
  """
274
317
  tmp = -((x - m_) ** 2)
@@ -277,72 +320,73 @@ class GaussianMixtureNoiseModel(nn.Module):
277
320
  tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
278
321
  return tmp
279
322
 
280
- def likelihood(self, observations, signals):
281
- """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.
323
+ def likelihood(
324
+ self, observations: torch.Tensor, signals: torch.Tensor
325
+ ) -> torch.Tensor:
326
+ """Evaluate the likelihood of observations given the signals and the
327
+ corresponding gaussian parameters.
282
328
 
283
329
  Parameters
284
330
  ----------
285
331
  observations : torch.cuda.FloatTensor
286
- Noisy observations
332
+ Noisy observations.
287
333
  signals : torch.cuda.FloatTensor
288
- Underlying signals
334
+ Underlying signals.
289
335
 
290
336
  Returns
291
337
  -------
292
338
  value :p + self.tol
293
339
  Likelihood of observations given the signals and the GMM noise model
294
340
  """
295
- self.to_device(signals)
341
+ self.to_device(signals) # move al needed stuff to the same device as `signals``
296
342
  gaussianParameters = self.getGaussianParameters(signals)
297
343
  p = 0
298
344
  for gaussian in range(self.n_gaussian):
299
345
  p += (
300
346
  self.normalDens(
301
- observations,
302
- gaussianParameters[gaussian],
303
- gaussianParameters[self.n_gaussian + gaussian],
347
+ x=observations,
348
+ m_=gaussianParameters[gaussian],
349
+ std_=gaussianParameters[self.n_gaussian + gaussian],
304
350
  )
305
351
  * gaussianParameters[2 * self.n_gaussian + gaussian]
306
352
  )
307
353
  return p + self.tol
308
354
 
309
- def getGaussianParameters(self, signals):
310
- """Returns the noise model for given signals
355
+ def getGaussianParameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
356
+ """Returns the noise model for given signals.
311
357
 
312
358
  Parameters
313
359
  ----------
314
- signals : torch.cuda.FloatTensor
360
+ signals : torch.Tensor
315
361
  Underlying signals
316
362
 
317
363
  Returns
318
364
  -------
319
- noiseModel: list of torch.cuda.FloatTensor
320
- Contains a list of `mu`, `sigma` and `alpha` for the `signals`
365
+ gmmParams: list[torch.Tensor]
366
+ A list containing tensors representing `mu`, `sigma` and `alpha`
367
+ parameters for the `n_gaussian` gaussians in the mixture.
321
368
 
322
369
  """
323
- noiseModel = []
370
+ gmmParams = []
324
371
  mu = []
325
372
  sigma = []
326
373
  alpha = []
327
374
  kernels = self.weight.shape[0] // 3
328
375
  for num in range(kernels):
376
+ # For each Gaussian in the mixture, evaluate mean, std and weight
329
377
  mu.append(self.polynomialRegressor(self.weight[num, :], signals))
330
- # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W))
378
+
331
379
  expval = torch.exp(self.weight[kernels + num, :])
332
- # self.maxval = max(self.maxval, expval.max().item())
380
+ # TODO: why taking the exp? it is not in PPN2V paper...
333
381
  sigmaTemp = self.polynomialRegressor(expval, signals)
334
382
  sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
335
383
  sigma.append(torch.sqrt(sigmaTemp))
336
384
 
337
- # expval = torch.exp(
338
- # torch.clamp(
339
- # self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) + self.tol, MAX_ALPHA_W))
340
385
  expval = torch.exp(
341
386
  self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
342
387
  + self.tol
343
388
  )
344
- # self.maxval = max(self.maxval, expval.max().item())
345
- alpha.append(expval)
389
+ alpha.append(expval) # NOTE: these are the numerators of weights
346
390
 
347
391
  sum_alpha = 0
348
392
  for al in range(kernels):
@@ -357,24 +401,24 @@ class GaussianMixtureNoiseModel(nn.Module):
357
401
  for ker in range(kernels):
358
402
  sum_means = alpha[ker] * mu[ker] + sum_means
359
403
 
360
- mu_shifted = []
361
404
  # subtracting the alpha weighted average of the means from the means
362
405
  # ensures that the GMM has the inclination to have the mean=signals.
363
- # its like a residual conection. I don't understand why we need to learn the mean?
406
+ # TODO: I don't understand why we need to learn the mean?
364
407
  for ker in range(kernels):
365
408
  mu[ker] = mu[ker] - sum_means + signals
366
409
 
367
410
  for i in range(kernels):
368
- noiseModel.append(mu[i])
411
+ gmmParams.append(mu[i])
369
412
  for j in range(kernels):
370
- noiseModel.append(sigma[j])
413
+ gmmParams.append(sigma[j])
371
414
  for k in range(kernels):
372
- noiseModel.append(alpha[k])
415
+ gmmParams.append(alpha[k])
373
416
 
374
- return noiseModel
417
+ return gmmParams
375
418
 
419
+ # TODO: this is to train the noise model
376
420
  def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
377
- """Returns the Signal-Observation pixel intensities as a two-column array
421
+ """Returns the Signal-Observation pixel intensities as a two-column array.
378
422
 
379
423
  Parameters
380
424
  ----------
@@ -389,7 +433,7 @@ class GaussianMixtureNoiseModel(nn.Module):
389
433
 
390
434
  Returns
391
435
  -------
392
- noiseModel: list of torch floats
436
+ gmmParams: list of torch floats
393
437
  Contains a list of `mu`, `sigma` and `alpha` for the `signals`
394
438
  """
395
439
  lb = np.percentile(signal, lowerClip)
@@ -407,3 +451,91 @@ class GaussianMixtureNoiseModel(nn.Module):
407
451
  (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
408
452
  ]
409
453
  return fastShuffle(sig_obs_pairs, 2)
454
+
455
+ # TODO: what's the use of this method?
456
+ def forward(self, x, y):
457
+ """Temporary dummy forward method."""
458
+ return x, y
459
+
460
+ # TODO taken from pn2v. Ashesh needs to clarify this
461
+ def train_noise_model(
462
+ self,
463
+ signal,
464
+ observation,
465
+ learning_rate=1e-1,
466
+ batchSize=250000,
467
+ n_epochs=2000,
468
+ name="GMMNoiseModel.npz",
469
+ lowerClip=0,
470
+ upperClip=100,
471
+ ):
472
+ """Training to learn the noise model from signal - observation pairs.
473
+
474
+ Parameters
475
+ ----------
476
+ signal: numpy array
477
+ Clean Signal Data
478
+ observation: numpy array
479
+ Noisy Observation Data
480
+ learning_rate: float
481
+ Learning rate. Default = 1e-1.
482
+ batchSize: int
483
+ Nini-batch size. Default = 250000.
484
+ n_epochs: int
485
+ Number of epochs. Default = 2000.
486
+ name: string
487
+
488
+ Model name. Default is `GMMNoiseModel`. This model after being trained is saved at the location `path`.
489
+
490
+ lowerClip : int
491
+ Lower percentile for clipping. Default is 0.
492
+ upperClip : int
493
+ Upper percentile for clipping. Default is 100.
494
+
495
+
496
+ """
497
+ sig_obs_pairs = self.getSignalObservationPairs(
498
+ signal, observation, lowerClip, upperClip
499
+ )
500
+ counter = 0
501
+ optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
502
+ for t in range(n_epochs):
503
+
504
+ jointLoss = 0
505
+ if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
506
+ counter = 0
507
+ sig_obs_pairs = fastShuffle(sig_obs_pairs, 1)
508
+
509
+ batch_vectors = sig_obs_pairs[
510
+ counter * batchSize : (counter + 1) * batchSize, :
511
+ ]
512
+ observations = batch_vectors[:, 1].astype(np.float32)
513
+ signals = batch_vectors[:, 0].astype(np.float32)
514
+ # TODO do we absolutely need to move to GPU?
515
+ observations = (
516
+ torch.from_numpy(observations.astype(np.float32)).float().cuda()
517
+ )
518
+ signals = torch.from_numpy(signals).float().cuda()
519
+ p = self.likelihood(observations, signals)
520
+ loss = torch.mean(-torch.log(p))
521
+ jointLoss = jointLoss + loss
522
+
523
+ if t % 100 == 0:
524
+ print(t, jointLoss.item())
525
+
526
+ if t % (int(n_epochs * 0.5)) == 0:
527
+ trained_weight = self.weight.cpu().detach().numpy()
528
+ min_signal = self.min_signal.cpu().detach().numpy()
529
+ max_signal = self.max_signal.cpu().detach().numpy()
530
+ # TODO do we need to save?
531
+ # np.savez(self.path+name, trained_weight=trained_weight, min_signal = min_signal, max_signal = max_signal, min_sigma = self.min_sigma)
532
+
533
+ optimizer.zero_grad()
534
+ jointLoss.backward()
535
+ optimizer.step()
536
+ counter += 1
537
+
538
+ print("===================\n")
539
+ # print("The trained parameters (" + name + ") is saved at location: "+ self.path)
540
+ # TODO return istead of save ?
541
+ return trained_weight, min_signal, max_signal, self.min_sigma