careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
from .utils import ModelType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DisentNoiseModel(nn.Module):
    """Container applying one noise model per channel.

    Each non-``None`` noise model passed to the constructor is registered as a
    sub-module named ``nmodel_{i}``, where ``i`` is its position in the
    argument list. ``likelihood`` evaluates channel ``i`` of the inputs with
    ``nmodel_{i}`` and concatenates the results along dimension 1.
    """

    def __init__(self, *nmodels):
        """
        Constructor.

        This class receives as input a variable number of noise models, each one corresponding to a channel.

        Parameters
        ----------
        *nmodels : nn.Module or None
            One entry per channel. ``None`` entries are skipped, so their slot
            ``nmodel_{i}`` is not registered. Each registered model must expose
            a ``likelihood(obs, signal)`` method.
        """
        super().__init__()
        # Number of registered (non-None) noise models.
        self._nm_cnt = 0
        for i, nmodel in enumerate(nmodels):
            if nmodel is None:
                continue
            self.add_module(f"nmodel_{i}", nmodel)
            self._nm_cnt += 1

        print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")

    def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
        """Evaluate the per-channel likelihood of ``obs`` given ``signal``.

        Parameters
        ----------
        obs : torch.Tensor
            Observations; dimension 1 indexes channels.
        signal : torch.Tensor
            Signals with the same channel layout as ``obs``.

        Returns
        -------
        torch.Tensor
            Per-channel likelihoods concatenated along dimension 1.

        Raises
        ------
        AssertionError
            If the channel count does not match the registered noise models.
        """
        if obs.shape[1] == 1:
            assert signal.shape[1] == 1
            # Single-channel data: only the first slot may hold a noise model.
            # nn.Module raises AttributeError for unregistered names, so use a
            # default to test whether a second model was registered.
            assert getattr(self, "nmodel_1", None) is None
            return self.nmodel_0.likelihood(obs, signal)

        assert obs.shape[1] == self._nm_cnt, f"{obs.shape[1]} != {self._nm_cnt}"

        ll_list = []
        for ch_idx in range(obs.shape[1]):
            nmodel = getattr(self, f"nmodel_{ch_idx}")
            # Slice with ch_idx:ch_idx+1 to keep the channel dimension.
            ll_list.append(
                nmodel.likelihood(
                    obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1]
                )
            )

        return torch.cat(ll_list, dim=1)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def last2path(fpath: str):
    """Return the final two ``/``-separated components of ``fpath`` as a path.

    The components are re-joined with ``os.path.join``. A path with a single
    component is returned unchanged.
    """
    components = fpath.split("/")
    tail = components[-2:]
    return os.path.join(*tail)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_nm_config(noise_model_fpath: str):
    """Load the ``config.json`` stored next to a noise-model file.

    Parameters
    ----------
    noise_model_fpath : str
        Path to a noise-model file. Only its directory is used.

    Returns
    -------
    Any
        The parsed contents of ``<dirname(noise_model_fpath)>/config.json``.
    """
    config_fpath = os.path.join(os.path.dirname(noise_model_fpath), "config.json")
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_fpath, encoding="utf-8") as f:
        noise_model_config = json.load(f)
    return noise_model_config
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def fastShuffle(series, num):
    """Randomly permute the rows of a 2D array ``num`` times.

    Each pass draws a fresh permutation from NumPy's global RNG and reorders
    the rows. The input array is not modified; a reordered copy is returned
    (or the input itself when ``num`` is 0).
    """
    n_rows = series.shape[0]
    shuffled = series
    for _ in range(num):
        order = np.random.permutation(n_rows)
        shuffled = shuffled[order, :]
    return shuffled
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _load_gmm_noise_model(fpath: str):
    """Build a GaussianMixtureNoiseModel from an ``.npz`` file and close the file."""
    # np.load on an .npz returns a lazily-read NpzFile holding an open handle;
    # the constructor copies every array it needs, so closing afterwards is safe.
    with np.load(fpath) as params:
        return GaussianMixtureNoiseModel(params=params)


def get_noise_model(
    enable_noise_model: bool,
    model_type: ModelType,
    noise_model_type: str,
    noise_model_ch1_fpath: str,
    noise_model_ch2_fpath: str,
    noise_model_learnable: bool = False,
    denoise_channel: str = "input",
):
    """Create the noise model wrapper for the given model type.

    Parameters
    ----------
    enable_noise_model : bool
        If False, no noise model is built and ``None`` is returned.
    model_type : ModelType
        ``ModelType.Denoiser`` builds a single noise model (one output). Any
        other type builds two noise models, one per output channel.
    noise_model_type : str
        Only ``"gmm"`` is supported.
    noise_model_ch1_fpath : str
        Path to the ``.npz`` parameters of the first channel's GMM noise model.
    noise_model_ch2_fpath : str
        Path to the ``.npz`` parameters of the second channel's GMM noise model.
    noise_model_learnable : bool, optional
        If True, ``make_learnable()`` is called on every built noise model.
    denoise_channel : str, optional
        Denoiser only: ``"Ch1"`` or ``"input"`` use ``noise_model_ch1_fpath``,
        ``"Ch2"`` uses ``noise_model_ch2_fpath``.

    Returns
    -------
    DisentNoiseModel or None
        The wrapped noise model(s), or ``None`` if disabled.

    Raises
    ------
    NotImplementedError
        If a ``"hist"`` noise model is requested for a denoiser.
    ValueError
        If ``noise_model_type`` or ``denoise_channel`` is not recognised.
    """
    if not enable_noise_model:
        return None

    if noise_model_type != "gmm":
        if model_type == ModelType.Denoiser and noise_model_type == "hist":
            raise NotImplementedError('"hist" noise model is not supported for now.')
        raise ValueError(f"Invalid noise_model_type: {noise_model_type}")

    if model_type == ModelType.Denoiser:
        # HDN -> one single output -> one single noise model
        channel_paths = {
            "Ch1": noise_model_ch1_fpath,
            "Ch2": noise_model_ch2_fpath,
            "input": noise_model_ch1_fpath,
        }
        if denoise_channel not in channel_paths:
            raise ValueError(f"Invalid denoise_channel: {denoise_channel}")
        nmodel_fpath = channel_paths[denoise_channel]
        print(f"Noise model {denoise_channel}: {nmodel_fpath}")
        # The single model always goes in the first slot, whichever path it came from.
        nmodels = [_load_gmm_noise_model(nmodel_fpath), None]
    else:
        # muSplit -> two outputs -> two noise models
        print(f"Noise model Ch1: {noise_model_ch1_fpath}")
        print(f"Noise model Ch2: {noise_model_ch2_fpath}")
        nmodels = [
            _load_gmm_noise_model(noise_model_ch1_fpath),
            _load_gmm_noise_model(noise_model_ch2_fpath),
        ]

    if noise_model_learnable:
        for nmodel in nmodels:
            if nmodel is not None:
                nmodel.make_learnable()

    return DisentNoiseModel(*nmodels)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class GaussianMixtureNoiseModel(nn.Module):
    """
    The GaussianMixtureNoiseModel class describes a noise model which is parameterized as a mixture of gaussians.
    If you would like to initialize a new object from scratch, then set `params`= None and specify the other parameters as keyword arguments.
    If you are instead loading a model, use only `params`.

    Each gaussian's mean, variance and mixture weight are polynomials in the
    signal, normalised to [0, 1] by `min_signal` / `max_signal`.

    Parameters
    ----------
    **kwargs: keyworded, variable-length argument dictionary.
    Arguments include:
        min_signal : float
            Minimum signal intensity expected in the image.
        max_signal : float
            Maximum signal intensity expected in the image.
        path: string
            Stored as `self.path` (from-scratch construction only); presumably the
            directory where a trained noise model (*.npz) is saved.
        weight : array
            A [3*n_gaussian, n_coeff] sized array containing the values of the weights describing the noise model.
            Each gaussian contributes three parameters (mean, standard deviation and weight), hence the number of rows in `weight` are 3*n_gaussian.
            If `weight=None`, the weight array is initialized using the `min_signal` and `max_signal` parameters.
        n_gaussian: int
            Number of gaussians.
        n_coeff: int
            Number of coefficients to describe the functional relationship between gaussian parameters and the signal.
            2 implies a linear relationship, 3 implies a quadratic relationship and so on.
        device: device
            Currently ignored.
        min_sigma: float
            Lower bound applied to each gaussian's variance before its square
            root is taken to obtain the standard deviation.
        params: dictionary
            Use `params` if one wishes to load a model with trained weights.
            Must provide the keys "min_signal", "max_signal", "trained_weight"
            and "min_sigma" (e.g. an opened .npz file).
            While initializing a new object of the class `GaussianMixtureNoiseModel` from scratch, set this to `None`.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self._learnable = False

        params = kwargs.get("params")
        if params is None:
            weight = kwargs.get("weight")
            n_gaussian = kwargs.get("n_gaussian")
            n_coeff = kwargs.get("n_coeff")
            min_signal = kwargs.get("min_signal")
            max_signal = kwargs.get("max_signal")
            self.path = kwargs.get("path")
            self.min_sigma = kwargs.get("min_sigma")
            if weight is None:
                weight = np.random.randn(n_gaussian * 3, n_coeff)
                # Variance coefficients are exponentiated in getGaussianParameters;
                # start the linear term at the signal range.
                weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
                weight = torch.from_numpy(weight.astype(np.float32)).float()
                weight = nn.Parameter(weight, requires_grad=True)

            self.n_gaussian = weight.shape[0] // 3
            self.n_coeff = weight.shape[1]
            self.weight = weight
            self.min_signal = torch.Tensor([min_signal])
            self.max_signal = torch.Tensor([max_signal])
            self.tol = torch.Tensor([1e-10])
        else:
            # Accept 0-d / size-1 arrays or plain numbers; store as shape-(1,) tensors.
            self.min_signal = torch.Tensor([np.asarray(params["min_signal"]).item()])
            self.max_signal = torch.Tensor([np.asarray(params["max_signal"]).item()])
            self.weight = torch.nn.Parameter(
                torch.Tensor(params["trained_weight"]), requires_grad=False
            )
            self.min_sigma = params["min_sigma"].item()
            self.n_gaussian = self.weight.shape[0] // 3
            self.n_coeff = self.weight.shape[1]
            self.tol = torch.Tensor([1e-10])

        print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")

    def make_learnable(self):
        """Enable gradients on the weight tensor."""
        print(f"[{self.__class__.__name__}] Making noise model learnable")

        self._learnable = True
        self.weight.requires_grad = True

    def to_device(self, cuda_tensor):
        """Move the model's tensors to the device of `cuda_tensor` if needed.

        `min_signal`, `max_signal` and `tol` are plain tensor attributes, so
        `nn.Module.to` does not move them; this method does.
        """
        device = cuda_tensor.device
        if self.min_signal.device != device:
            self.max_signal = self.max_signal.to(device)
            self.min_signal = self.min_signal.to(device)
            self.tol = self.tol.to(device)
        if self.weight.device != device:
            # Tensor.to returns a plain tensor, which nn.Module refuses to assign
            # to a registered Parameter. Swapping .data keeps the same Parameter
            # object (and any optimizer reference to it).
            self.weight.data = self.weight.data.to(device)
            if self._learnable:
                self.weight.requires_grad = True

    def polynomialRegressor(self, weightParams, signals):
        """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values.

        Parameters
        ----------
        weightParams : torch.cuda.FloatTensor
            Corresponds to specific rows of the `self.weight`
        signals : torch.cuda.FloatTensor
            Signals

        Returns
        -------
        value : torch.cuda.FloatTensor
            Corresponds to either of mean, standard deviation or weight, evaluated at `signals`
        """
        # Signal rescaled so that min_signal -> 0 and max_signal -> 1.
        normalized = (signals - self.min_signal) / (self.max_signal - self.min_signal)
        value = 0
        for i in range(weightParams.shape[0]):
            value += weightParams[i] * (normalized**i)
        return value

    def normalDens(self, x, m_=0.0, std_=None):
        """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`.

        Parameters
        ----------
        x: torch.cuda.FloatTensor
            Observations
        m_: torch.cuda.FloatTensor
            Mean
        std_: torch.cuda.FloatTensor
            Standard-deviation

        Returns
        -------
        tmp: torch.cuda.FloatTensor
            Normal probability density of `x` given `m_` and `std_`
        """
        tmp = -((x - m_) ** 2)
        tmp = tmp / (2.0 * std_ * std_)
        tmp = torch.exp(tmp)
        tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
        return tmp

    def likelihood(self, observations, signals):
        """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.

        Parameters
        ----------
        observations : torch.cuda.FloatTensor
            Noisy observations
        signals : torch.cuda.FloatTensor
            Underlying signals

        Returns
        -------
        value : torch.Tensor
            Likelihood of observations given the signals and the GMM noise
            model, offset by `self.tol` to keep it strictly positive.
        """
        self.to_device(signals)
        gaussianParameters = self.getGaussianParameters(signals)
        p = 0
        for gaussian in range(self.n_gaussian):
            p += (
                self.normalDens(
                    observations,
                    gaussianParameters[gaussian],
                    gaussianParameters[self.n_gaussian + gaussian],
                )
                * gaussianParameters[2 * self.n_gaussian + gaussian]
            )
        return p + self.tol

    def getGaussianParameters(self, signals):
        """Returns the noise model for given signals

        Parameters
        ----------
        signals : torch.cuda.FloatTensor
            Underlying signals

        Returns
        -------
        noiseModel: list of torch.cuda.FloatTensor
            `n_gaussian` means, then `n_gaussian` standard deviations, then
            `n_gaussian` mixture weights (summing to 1), evaluated at `signals`.
        """
        kernels = self.weight.shape[0] // 3
        mu, sigma, alpha = [], [], []
        for num in range(kernels):
            mu.append(self.polynomialRegressor(self.weight[num, :], signals))

            # Variance coefficients are exponentiated so they stay positive.
            expval = torch.exp(self.weight[kernels + num, :])
            sigmaTemp = self.polynomialRegressor(expval, signals)
            sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
            sigma.append(torch.sqrt(sigmaTemp))

            expval = torch.exp(
                self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
                + self.tol
            )
            alpha.append(expval)

        # Normalise the mixture weights so they sum to 1.
        sum_alpha = sum(alpha)
        alpha = [a / sum_alpha for a in alpha]

        # Subtracting the alpha-weighted average of the means and adding the
        # signal makes the mixture mean equal to `signals` (residual-style).
        sum_means = sum(a * m for a, m in zip(alpha, mu))
        mu = [m - sum_means + signals for m in mu]

        return mu + sigma + alpha

    def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
        """Returns the Signal-Observation pixel intensities as a two-column array

        Each observation ``i`` is paired with signal ``i // (n_observations // n_signals)``,
        so every signal is assumed to have the same number of observations.

        Parameters
        ----------
        signal : numpy array
            Clean Signal Data
        observation: numpy array
            Noisy observation Data
        lowerClip: float
            Lower percentile bound for clipping.
        upperClip: float
            Upper percentile bound for clipping.

        Returns
        -------
        sig_obs_pairs: numpy array
            Array of shape (N, 2) with signal intensities in column 0 and
            observation intensities in column 1. Only rows whose signal lies
            strictly between the lower and upper percentiles of `signal` are
            kept. Rows are randomly shuffled.

        Raises
        ------
        ValueError
            If the number of observations is not a positive multiple of the
            number of signals.
        """
        lb = np.percentile(signal, lowerClip)
        ub = np.percentile(signal, upperClip)
        stepsize = observation[0].size
        n_observations = observation.shape[0]
        n_signals = signal.shape[0]
        if n_signals == 0 or n_observations % n_signals:
            raise ValueError(
                f"Number of observations ({n_observations}) must be a multiple "
                f"of the number of signals ({n_signals})."
            )
        reps = n_observations // n_signals

        sig_obs_pairs = np.zeros((n_observations * stepsize, 2))
        # Row-block i of the repeated signal stack is signal[i // reps], matching
        # the observation order; C-order flattening concatenates per-image ravels.
        sig_obs_pairs[:, 0] = np.repeat(signal, reps, axis=0).reshape(-1)
        sig_obs_pairs[:, 1] = np.asarray(observation).reshape(-1)
        sig_obs_pairs = sig_obs_pairs[
            (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
        ]
        return fastShuffle(sig_obs_pairs, 2)
|