careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/noise_models.py (new file)
@@ -0,0 +1,409 @@
+ import json
+ import os
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from .utils import ModelType
+
+
+ class DisentNoiseModel(nn.Module):
+
+     def __init__(self, *nmodels):
+         """
+         Constructor.
+
+         This class receives as input a variable number of noise models, each one corresponding to a channel.
+         """
+         super().__init__()
+         # self.nmodels = nmodels
+         for i, nmodel in enumerate(nmodels):
+             if nmodel is not None:
+                 self.add_module(f"nmodel_{i}", nmodel)
+
+         self._nm_cnt = 0
+         for nmodel in nmodels:
+             if nmodel is not None:
+                 self._nm_cnt += 1
+
+         print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")
+
+     def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
+
+         if obs.shape[1] == 1:
+             assert signal.shape[1] == 1
+             assert self._nm_cnt == 1  # a single channel implies a single registered noise model
+             return self.nmodel_0.likelihood(obs, signal)
+
+         assert obs.shape[1] == self._nm_cnt, f"{obs.shape[1]} != {self._nm_cnt}"
+
+         ll_list = []
+         for ch_idx in range(obs.shape[1]):
+             nmodel = getattr(self, f"nmodel_{ch_idx}")
+             ll_list.append(
+                 nmodel.likelihood(
+                     obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1]
+                 )
+             )
+
+         return torch.cat(ll_list, dim=1)
+
+
+ def last2path(fpath: str):
+     return os.path.join(*fpath.split("/")[-2:])
+
+
+ def get_nm_config(noise_model_fpath: str):
+     config_fpath = os.path.join(os.path.dirname(noise_model_fpath), "config.json")
+     with open(config_fpath) as f:
+         noise_model_config = json.load(f)
+     return noise_model_config
+
+
+ def fastShuffle(series, num):
+     length = series.shape[0]
+     for i in range(num):
+         series = series[np.random.permutation(length), :]
+     return series
+
+
+ def get_noise_model(
+     enable_noise_model: bool,
+     model_type: ModelType,
+     noise_model_type: str,
+     noise_model_ch1_fpath: str,
+     noise_model_ch2_fpath: str,
+     noise_model_learnable: bool = False,
+     denoise_channel: str = "input",
+ ):
+     if enable_noise_model:
+         nmodels = []
+         # HDN -> one single output -> one single noise model
+         if model_type == ModelType.Denoiser:
+             if noise_model_type == "hist":
+                 raise NotImplementedError(
+                     '"hist" noise model is not supported for now.'
+                 )
+             elif noise_model_type == "gmm":
+                 if denoise_channel == "Ch1":
+                     nmodel_fpath = noise_model_ch1_fpath
+                     print(f"Noise model Ch1: {nmodel_fpath}")
+                     nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
+                     nmodel2 = None
+                     nmodels = [nmodel1, nmodel2]
+                 elif denoise_channel == "Ch2":
+                     nmodel_fpath = noise_model_ch2_fpath
+                     print(f"Noise model Ch2: {nmodel_fpath}")
+                     nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
+                     nmodel2 = None
+                     nmodels = [nmodel1, nmodel2]
+                 elif denoise_channel == "input":
+                     nmodel_fpath = noise_model_ch1_fpath
+                     print(f"Noise model input: {nmodel_fpath}")
+                     nmodel1 = GaussianMixtureNoiseModel(params=np.load(nmodel_fpath))
+                     nmodel2 = None
+                     nmodels = [nmodel1, nmodel2]
+                 else:
+                     raise ValueError(f"Invalid denoise_channel: {denoise_channel}")
+         # muSplit -> two outputs -> two noise models
+         elif noise_model_type == "gmm":
+             print(f"Noise model Ch1: {noise_model_ch1_fpath}")
+             print(f"Noise model Ch2: {noise_model_ch2_fpath}")
+
+             nmodel1 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch1_fpath))
+             nmodel2 = GaussianMixtureNoiseModel(params=np.load(noise_model_ch2_fpath))
+
+             nmodels = [nmodel1, nmodel2]
+
+             # if 'noise_model_ch3_fpath' in config.model:
+             #     print(f'Noise model Ch3: {config.model.noise_model_ch3_fpath}')
+             #     nmodel3 = GaussianMixtureNoiseModel(params=np.load(config.model.noise_model_ch3_fpath))
+             #     nmodels = [nmodel1, nmodel2, nmodel3]
+             # else:
+             #     nmodels = [nmodel1, nmodel2]
+         else:
+             raise ValueError(f"Invalid noise_model_type: {noise_model_type}")
+
+         if noise_model_learnable:
+             for nmodel in nmodels:
+                 if nmodel is not None:
+                     nmodel.make_learnable()
+
+         return DisentNoiseModel(*nmodels)
+     return None
+
+
+ class GaussianMixtureNoiseModel(nn.Module):
+     """
+     The GaussianMixtureNoiseModel class describes a noise model which is parameterized as a mixture of gaussians.
+     If you would like to initialize a new object from scratch, then set `params` = None and specify the other parameters as keyword arguments.
+     If you are instead loading a model, use only `params`.
+
+     Parameters
+     ----------
+     **kwargs: keyworded, variable-length argument dictionary.
+         Arguments include:
+             min_signal : float
+                 Minimum signal intensity expected in the image.
+             max_signal : float
+                 Maximum signal intensity expected in the image.
+             path: string
+                 Path to the directory where the trained noise model (*.npz) is saved in the `train` method.
+             weight : array
+                 A [3*n_gaussian, n_coeff] sized array containing the values of the weights describing the noise model.
+                 Each gaussian contributes three parameters (mean, standard deviation and weight), hence the number of rows in `weight` are 3*n_gaussian.
+                 If `weight=None`, the weight array is initialized using the `min_signal` and `max_signal` parameters.
+             n_gaussian: int
+                 Number of gaussians.
+             n_coeff: int
+                 Number of coefficients to describe the functional relationship between gaussian parameters and the signal.
+                 2 implies a linear relationship, 3 implies a quadratic relationship and so on.
+             device: device
+                 GPU device.
+             min_sigma: int
+                 All values of sigma (`standard deviation`) below min_sigma are clamped to become equal to min_sigma.
+             params: dictionary
+                 Use `params` if one wishes to load a model with trained weights.
+                 While initializing a new object of the class `GaussianMixtureNoiseModel` from scratch, set this to `None`.
+     """
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         self._learnable = False
+
+         if kwargs.get("params") is None:
+             weight = kwargs.get("weight")
+             n_gaussian = kwargs.get("n_gaussian")
+             n_coeff = kwargs.get("n_coeff")
+             min_signal = kwargs.get("min_signal")
+             max_signal = kwargs.get("max_signal")
+             # self.device = kwargs.get('device')
+             self.path = kwargs.get("path")
+             self.min_sigma = kwargs.get("min_sigma")
+             if weight is None:
+                 weight = np.random.randn(n_gaussian * 3, n_coeff)
+                 weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
+                 weight = torch.from_numpy(
+                     weight.astype(np.float32)
+                 ).float()  # .to(self.device)
+                 weight = nn.Parameter(weight, requires_grad=True)
+
+             self.n_gaussian = weight.shape[0] // 3
+             self.n_coeff = weight.shape[1]
+             self.weight = weight
+             self.min_signal = torch.Tensor([min_signal])  # .to(self.device)
+             self.max_signal = torch.Tensor([max_signal])  # .to(self.device)
+             self.tol = torch.Tensor([1e-10])  # .to(self.device)
+         else:
+             params = kwargs.get("params")
+             # self.device = kwargs.get('device')
+
+             self.min_signal = torch.Tensor(params["min_signal"])  # .to(self.device)
+             self.max_signal = torch.Tensor(params["max_signal"])  # .to(self.device)
+
+             self.weight = torch.nn.Parameter(
+                 torch.Tensor(params["trained_weight"]), requires_grad=False
+             )  # .to(self.device)
+             self.min_sigma = params["min_sigma"].item()
+             self.n_gaussian = self.weight.shape[0] // 3
+             self.n_coeff = self.weight.shape[1]
+             self.tol = torch.Tensor([1e-10])  # .to(self.device)
+             self.min_signal = torch.Tensor([self.min_signal])  # .to(self.device)
+             self.max_signal = torch.Tensor([self.max_signal])  # .to(self.device)
+
+         print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")
+
+     def make_learnable(self):
+         print(f"[{self.__class__.__name__}] Making noise model learnable")
+
+         self._learnable = True
+         self.weight.requires_grad = True
+
+     #
+
+     def to_device(self, cuda_tensor):
+         # move everything to GPU
+         if self.min_signal.device != cuda_tensor.device:
+             self.max_signal = self.max_signal.to(cuda_tensor.device)
+             self.min_signal = self.min_signal.to(cuda_tensor.device)
+             self.tol = self.tol.to(cuda_tensor.device)
+             self.weight = self.weight.to(cuda_tensor.device)
+             if self._learnable:
+                 self.weight.requires_grad = True
+
+     def polynomialRegressor(self, weightParams, signals):
+         """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values.
+
+         Parameters
+         ----------
+         weightParams : torch.cuda.FloatTensor
+             Corresponds to specific rows of the `self.weight`
+         signals : torch.cuda.FloatTensor
+             Signals
+
+         Returns
+         -------
+         value : torch.cuda.FloatTensor
+             Corresponds to either of mean, standard deviation or weight, evaluated at `signals`
+         """
+         value = 0
+         for i in range(weightParams.shape[0]):
+             value += weightParams[i] * (
+                 ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
+             )
+         return value
+
+     def normalDens(self, x, m_=0.0, std_=None):
+         """Evaluates the normal probability density at `x` given the mean `m_` and standard deviation `std_`.
+
+         Parameters
+         ----------
+         x: torch.cuda.FloatTensor
+             Observations
+         m_: torch.cuda.FloatTensor
+             Mean
+         std_: torch.cuda.FloatTensor
+             Standard-deviation
+
+         Returns
+         -------
+         tmp: torch.cuda.FloatTensor
+             Normal probability density of `x` given `m_` and `std_`
+         """
+         tmp = -((x - m_) ** 2)
+         tmp = tmp / (2.0 * std_ * std_)
+         tmp = torch.exp(tmp)
+         tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
+         return tmp
+
+     def likelihood(self, observations, signals):
+         """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.
+
+         Parameters
+         ----------
+         observations : torch.cuda.FloatTensor
+             Noisy observations
+         signals : torch.cuda.FloatTensor
+             Underlying signals
+
+         Returns
+         -------
+         value : torch.cuda.FloatTensor
+             Likelihood `p + self.tol` of the observations given the signals and the GMM noise model
+         """
+         self.to_device(signals)
+         gaussianParameters = self.getGaussianParameters(signals)
+         p = 0
+         for gaussian in range(self.n_gaussian):
+             p += (
+                 self.normalDens(
+                     observations,
+                     gaussianParameters[gaussian],
+                     gaussianParameters[self.n_gaussian + gaussian],
+                 )
+                 * gaussianParameters[2 * self.n_gaussian + gaussian]
+             )
+         return p + self.tol
+
+     def getGaussianParameters(self, signals):
+         """Returns the noise model for given signals
+
+         Parameters
+         ----------
+         signals : torch.cuda.FloatTensor
+             Underlying signals
+
+         Returns
+         -------
+         noiseModel: list of torch.cuda.FloatTensor
+             Contains a list of `mu`, `sigma` and `alpha` for the `signals`
+
+         """
+         noiseModel = []
+         mu = []
+         sigma = []
+         alpha = []
+         kernels = self.weight.shape[0] // 3
+         for num in range(kernels):
+             mu.append(self.polynomialRegressor(self.weight[num, :], signals))
+             # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W))
+             expval = torch.exp(self.weight[kernels + num, :])
+             # self.maxval = max(self.maxval, expval.max().item())
+             sigmaTemp = self.polynomialRegressor(expval, signals)
+             sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
+             sigma.append(torch.sqrt(sigmaTemp))
+
+             # expval = torch.exp(
+             #     torch.clamp(
+             #         self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) + self.tol, MAX_ALPHA_W))
+             expval = torch.exp(
+                 self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
+                 + self.tol
+             )
+             # self.maxval = max(self.maxval, expval.max().item())
+             alpha.append(expval)
+
+         sum_alpha = 0
+         for al in range(kernels):
+             sum_alpha = alpha[al] + sum_alpha
+
+         # sum of alpha is forced to be 1.
+         for ker in range(kernels):
+             alpha[ker] = alpha[ker] / sum_alpha
+
+         sum_means = 0
+         # sum_means is the alpha weighted average of the means
+         for ker in range(kernels):
+             sum_means = alpha[ker] * mu[ker] + sum_means
+
+         mu_shifted = []
+         # subtracting the alpha weighted average of the means from the means
+         # ensures that the GMM has the inclination to have the mean=signals.
+         # it's like a residual connection; it is not obvious why the mean needs to be learned at all.
+         for ker in range(kernels):
+             mu[ker] = mu[ker] - sum_means + signals
+
+         for i in range(kernels):
+             noiseModel.append(mu[i])
+         for j in range(kernels):
+             noiseModel.append(sigma[j])
+         for k in range(kernels):
+             noiseModel.append(alpha[k])
+
+         return noiseModel
+
+     def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
+         """Returns the Signal-Observation pixel intensities as a two-column array
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean Signal Data
+         observation: numpy array
+             Noisy observation Data
+         lowerClip: float
+             Lower percentile bound for clipping.
+         upperClip: float
+             Upper percentile bound for clipping.
+
+         Returns
+         -------
+         sig_obs_pairs : numpy array
+             Shuffled two-column array of (signal, observation) pixel-intensity pairs
+         """
+         lb = np.percentile(signal, lowerClip)
+         ub = np.percentile(signal, upperClip)
+         stepsize = observation[0].size
+         n_observations = observation.shape[0]
+         n_signals = signal.shape[0]
+         sig_obs_pairs = np.zeros((n_observations * stepsize, 2))
+
+         for i in range(n_observations):
+             j = i // (n_observations // n_signals)
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel()
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel()
+         sig_obs_pairs = sig_obs_pairs[
+             (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
+         ]
+         return fastShuffle(sig_obs_pairs, 2)
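
For orientation, the model evaluated by GaussianMixtureNoiseModel.likelihood (via getGaussianParameters and normalDens) can be summarized directly from the code. Writing z = (s - min_signal) / (max_signal - min_signal) for the normalized signal, K = n_gaussian, and w for the rows of self.weight:

% Signal-dependent mixture parameters, each a polynomial in z:
\mu_k(s) = \sum_i w_{k,i}\, z^i, \qquad
\sigma_k(s) = \sqrt{\max\!\Big(\sum_i e^{w_{K+k,i}}\, z^i,\ \sigma_{\min}\Big)}, \qquad
\alpha_k(s) \propto \exp\!\Big(\sum_i w_{2K+k,i}\, z^i\Big)

% Returned per-pixel likelihood:
p(x \mid s) = \sum_{k=1}^{K} \alpha_k(s)\, \mathcal{N}\!\big(x;\ \mu_k(s),\ \sigma_k(s)^2\big) + \mathrm{tol}

where the \alpha_k(s) are normalized to sum to one and the means are recentered so that \sum_k \alpha_k(s)\,\mu_k(s) = s.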
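
As a minimal usage sketch (the file names, tensor shapes, and values below are hypothetical, not taken from the careamics documentation), the new module can be driven by loading trained .npz parameter files, which provide the params mapping (keys min_signal, max_signal, trained_weight, min_sigma) that GaussianMixtureNoiseModel expects:

import numpy as np
import torch

from careamics.models.lvae.noise_models import (
    DisentNoiseModel,
    GaussianMixtureNoiseModel,
)

# Hypothetical trained-parameter files; np.load(...) yields the mapping
# consumed as `params` by GaussianMixtureNoiseModel.__init__.
nm_ch1 = GaussianMixtureNoiseModel(params=np.load("noise_model_ch1.npz"))
nm_ch2 = GaussianMixtureNoiseModel(params=np.load("noise_model_ch2.npz"))

# One noise model per output channel, wrapped for multi-channel likelihoods.
noise_model = DisentNoiseModel(nm_ch1, nm_ch2)

# Per-pixel likelihood of noisy observations given predicted signals,
# following the (batch, channel, H, W) layout assumed by `likelihood`.
obs = torch.rand(4, 2, 64, 64)
signal = torch.rand(4, 2, 64, 64)
ll = noise_model.likelihood(obs, signal)  # shape (4, 2, 64, 64)

Alternatively, get_noise_model(...) builds the same wrapper from file paths, as shown in the diff above.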