careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/losses/noise_models.py
@@ -0,0 +1,524 @@
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch
+
+ from ..utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+
+ # TODO here "Model" clashes a bit with the naming convention of the Pydantic models
+ class NoiseModel(ABC):
+     """Base class for noise models."""
+
+     @abstractmethod
+     def instantiate(self):
+         """Instantiate the noise model.
+
+         This method should produce a ready-to-use noise model.
+         """
+         pass
+
+     @abstractmethod
+     def likelihood(self, observations, signals):
+         """Return the likelihood of the observations given the signals."""
+         pass
+
+
+ class HistogramNoiseModel(NoiseModel):
+     """Creates a NoiseModel object.
+
+     Parameters
+     ----------
+     histogram : numpy array
+         A histogram as created by the `instantiate(...)` method.
+     device :
+         The device your NoiseModel lives on, e.g. your GPU.
+     """
+
+     def __init__(self, **kwargs):
+         pass
+
+     def instantiate(self, bins, min_value, max_value, observation, signal):
+         """Create an nD histogram from `observation` and `signal`.
+
+         Parameters
+         ----------
+         bins : int
+             The number of bins in all dimensions. The total number of bins is
+             `bins` ** number_of_dimensions.
+         min_value : float
+             The lower bound of the lowest bin.
+         max_value : float
+             The upper bound of the highest bin.
+         observation : np.array
+             A stack of noisy images. The number of images has to be divisible by
+             the number of images in `signal`. N subsequent images in `observation`
+             belong to one image in `signal`.
+         signal : np.array
+             A stack of clean images.
+
+         Returns
+         -------
+         histogram : numpy array
+             A 3D array:
+             `histogram[0, ...]` holds the normalized nD counts.
+             Each row sums to 1, describing p(x_i|s_i).
+             `histogram[1, ...]` holds the lower boundaries of each bin in y.
+             `histogram[2, ...]` holds the upper boundaries of each bin in y.
+             The values for x can be obtained by transposing `histogram[1, ...]`
+             and `histogram[2, ...]`.
+         """
+         img_factor = int(observation.shape[0] / signal.shape[0])
+         histogram = np.zeros((3, bins, bins))
+         value_range = [min_value, max_value]
+
+         for i in range(observation.shape[0]):
+             observation_i = observation[i].copy().ravel()
+
+             signal_i = (signal[i // img_factor].copy()).ravel()
+
+             histogram_i = np.histogramdd(
+                 (signal_i, observation_i), bins=bins, range=[value_range, value_range]
+             )
+             # Adding a small constant for numerical stability
+             histogram[0] = histogram[0] + histogram_i[0] + 1e-30
+
+         for i in range(bins):
+             # Exclude empty rows from normalization
+             if np.sum(histogram[0, i, :]) > 1e-20:
+                 # Normalize each non-empty row
+                 histogram[0, i, :] /= np.sum(histogram[0, i, :])
+
+         for i in range(bins):
+             # The lower boundaries of each bin in y are stored in dimension 1
+             histogram[1, :, i] = histogram_i[1][:-1]
+             # The upper boundaries of each bin in y are stored in dimension 2
+             histogram[2, :, i] = histogram_i[1][1:]
+             # The corresponding values for x are just transposed.
+
+         return histogram
+
+     def likelihood(self, observed, signal):
+         """Calculate the likelihood using a histogram-based noise model.
+
+         For every pixel in a tensor, calculate p(x_i|s_i). To ensure
+         differentiability in the direction of s_i, we linearly interpolate in
+         this direction.
+
+         Parameters
+         ----------
+         observed : torch.Tensor
+             Tensor holding the observed intensities x_i.
+         signal : torch.Tensor
+             Tensor holding hypotheses for the clean signal at every pixel s_i^k.
+
+         Returns
+         -------
+         torch.Tensor
+             The observation likelihoods according to the noise model.
+         """
+         observed_float = self.get_index_observed_float(observed)
+         observed_long = observed_float.floor().long()
+         signal_float = self.get_index_signal_float(signal)
+         signal_long = signal_float.floor().long()
+         fact = signal_float - signal_long.float()
+
+         # Finally, look up the values and linearly interpolate
+         return self.fullHist[signal_long, observed_long] * (1.0 - fact) + self.fullHist[
+             torch.clamp((signal_long + 1).long(), 0, self.bins.long()), observed_long
+         ] * (fact)
+
+     def get_index_observed_float(self, x):
+         """Map observed intensities to fractional histogram bin indices.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Observed intensities.
+
+         Returns
+         -------
+         torch.Tensor
+             Fractional bin indices, clamped to the valid bin range.
+         """
+         return torch.clamp(
+             self.bins * (x - self.minv) / (self.maxv - self.minv),
+             min=0.0,
+             max=self.bins - 1 - 1e-3,
+         )
+
+     def get_index_signal_float(self, x):
+         """Map signal intensities to fractional histogram bin indices.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Signal intensities.
+
+         Returns
+         -------
+         torch.Tensor
+             Fractional bin indices, clamped to the valid bin range.
+         """
+         return torch.clamp(
+             self.bins * (x - self.minv) / (self.maxv - self.minv),
+             min=0.0,
+             max=self.bins - 1 - 1e-3,
+         )
+
+
+ # TODO refactor this into a Pydantic model
+ class GaussianMixtureNoiseModel(NoiseModel):
+     """Describes a noise model parameterized as a mixture of Gaussians.
+
+     To initialize a new object from scratch, set `params = None` and specify the
+     other parameters as keyword arguments. To load an existing model instead,
+     use only `params`.
+
+     Parameters
+     ----------
+     **kwargs : keyworded, variable-length argument dictionary.
+         Arguments include:
+         min_signal : float
+             Minimum signal intensity expected in the image.
+         max_signal : float
+             Maximum signal intensity expected in the image.
+         weight : array
+             A [3*n_gaussian, n_coeff] sized array containing the values of the
+             weights describing the noise model.
+             Each Gaussian contributes three parameters (mean, standard deviation
+             and weight), hence the number of rows in `weight` is 3*n_gaussian.
+             If `weight = None`, the weight array is initialized using the
+             `min_signal` and `max_signal` parameters.
+         n_gaussian : int
+             Number of Gaussians.
+         n_coeff : int
+             Number of coefficients describing the functional relationship between
+             the Gaussian parameters and the signal.
+             2 implies a linear relationship, 3 a quadratic relationship, and so on.
+         device : device
+             GPU device.
+         min_sigma : int
+             All values of sigma (the standard deviation) below `min_sigma` are
+             clamped to `min_sigma`.
+         params : dictionary
+             Use `params` to load a model with trained weights.
+             When initializing a new `GaussianMixtureNoiseModel` object from
+             scratch, set this to `None`.
+     """
+
+     def __init__(self, **kwargs):
+         if kwargs.get("params") is None:
+             weight = kwargs.get("weight")
+             n_gaussian = kwargs.get("n_gaussian")
+             n_coeff = kwargs.get("n_coeff")
+             min_signal = kwargs.get("min_signal")
+             max_signal = kwargs.get("max_signal")
+             self.device = kwargs.get("device")
+             self.path = kwargs.get("path")
+             self.min_sigma = kwargs.get("min_sigma")
+             if weight is None:
+                 weight = np.random.randn(n_gaussian * 3, n_coeff)
+                 weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
+                 weight = (
+                     torch.from_numpy(weight.astype(np.float32)).float().to(self.device)
+                 )
+                 weight.requires_grad = True
+             self.n_gaussian = weight.shape[0] // 3
+             self.n_coeff = weight.shape[1]
+             self.weight = weight
+             self.min_signal = torch.Tensor([min_signal]).to(self.device)
+             self.max_signal = torch.Tensor([max_signal]).to(self.device)
+             self.tol = torch.Tensor([1e-10]).to(self.device)
+         else:
+             params = kwargs.get("params")
+             self.device = kwargs.get("device")
+
+             self.min_signal = torch.Tensor(params["min_signal"]).to(self.device)
+             self.max_signal = torch.Tensor(params["max_signal"]).to(self.device)
+
+             self.weight = torch.Tensor(params["trained_weight"]).to(self.device)
+             self.min_sigma = np.ndarray.item(params["min_sigma"])
+             self.n_gaussian = self.weight.shape[0] // 3
+             self.n_coeff = self.weight.shape[1]
+             self.tol = torch.Tensor([1e-10]).to(self.device)
+             self.min_signal = torch.Tensor([self.min_signal]).to(self.device)
+             self.max_signal = torch.Tensor([self.max_signal]).to(self.device)
+
+     def fast_shuffle(self, series, num):
+         """Shuffle the rows of `series` `num` times.
+
+         Parameters
+         ----------
+         series : numpy array
+             Two-column array to shuffle.
+         num : int
+             Number of shuffling passes.
+
+         Returns
+         -------
+         series : numpy array
+             The row-shuffled array.
+         """
+         length = series.shape[0]
+         for _i in range(num):
+             series = series[np.random.permutation(length), :]
+         return series
+
+     def polynomial_regressor(self, weightParams, signals):
+         """Combine `weightParams` and `signals` to perform regression.
+
+         Parameters
+         ----------
+         weightParams : torch.cuda.FloatTensor
+             Corresponds to specific rows of `self.weight`.
+         signals : torch.cuda.FloatTensor
+             Signals.
+
+         Returns
+         -------
+         value : torch.cuda.FloatTensor
+             Corresponds to either the mean, standard deviation or weight,
+             evaluated at `signals`.
+         """
+         value = 0
+         for i in range(weightParams.shape[0]):
+             value += weightParams[i] * (
+                 ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
+             )
+         return value
+
+     def normal_density(self, x, m_=0.0, std_=None):
+         """Evaluate the normal probability density.
+
+         Parameters
+         ----------
+         x : torch.cuda.FloatTensor
+             Observations.
+         m_ : torch.cuda.FloatTensor
+             Mean.
+         std_ : torch.cuda.FloatTensor
+             Standard deviation.
+
+         Returns
+         -------
+         tmp : torch.cuda.FloatTensor
+             Normal probability density of `x` given `m_` and `std_`.
+         """
+         tmp = -((x - m_) ** 2)
+         tmp = tmp / (2.0 * std_ * std_)
+         tmp = torch.exp(tmp)
+         tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
+         return tmp
+
+     def likelihood(self, observations, signals):
+         """Evaluate the likelihood of the observations.
+
+         Given the signals and the corresponding Gaussian parameters, evaluates
+         the likelihood of the observations.
+
+         Parameters
+         ----------
+         observations : torch.cuda.FloatTensor
+             Noisy observations.
+         signals : torch.cuda.FloatTensor
+             Underlying signals.
+
+         Returns
+         -------
+         p + self.tol : torch.cuda.FloatTensor
+             Likelihood of the observations given the signals and the GMM noise
+             model.
+         """
+         gaussianParameters = self.get_gaussian_parameters(signals)
+         p = 0
+         for gaussian in range(self.n_gaussian):
+             p += (
+                 self.normal_density(
+                     observations,
+                     gaussianParameters[gaussian],
+                     gaussianParameters[self.n_gaussian + gaussian],
+                 )
+                 * gaussianParameters[2 * self.n_gaussian + gaussian]
+             )
+         return p + self.tol
+
+     def get_gaussian_parameters(self, signals):
+         """Return the noise model parameters for the given signals.
+
+         Parameters
+         ----------
+         signals : torch.cuda.FloatTensor
+             Underlying signals.
+
+         Returns
+         -------
+         noiseModel : list of torch.cuda.FloatTensor
+             A list of `mu`, `sigma` and `alpha` for the `signals`.
+         """
+         noiseModel = []
+         mu = []
+         sigma = []
+         alpha = []
+         kernels = self.weight.shape[0] // 3
+         for num in range(kernels):
+             mu.append(self.polynomial_regressor(self.weight[num, :], signals))
+
+             sigmaTemp = self.polynomial_regressor(
+                 torch.exp(self.weight[kernels + num, :]), signals
+             )
+             sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
+             sigma.append(torch.sqrt(sigmaTemp))
+             alpha.append(
+                 torch.exp(
+                     self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
+                     + self.tol
+                 )
+             )
+
+         sum_alpha = 0
+         for al in range(kernels):
+             sum_alpha = alpha[al] + sum_alpha
+         for ker in range(kernels):
+             alpha[ker] = alpha[ker] / sum_alpha
+
+         sum_means = 0
+         for ker in range(kernels):
+             sum_means = alpha[ker] * mu[ker] + sum_means
+
+         for ker in range(kernels):
+             mu[ker] = mu[ker] - sum_means + signals
+
+         for i in range(kernels):
+             noiseModel.append(mu[i])
+         for j in range(kernels):
+             noiseModel.append(sigma[j])
+         for k in range(kernels):
+             noiseModel.append(alpha[k])
+
+         return noiseModel
+
+     def get_signal_observation_pairs(self, signal, observation, lowerClip, upperClip):
+         """Return the signal-observation pixel intensities as a two-column array.
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean signal data.
+         observation : numpy array
+             Noisy observation data.
+         lowerClip : float
+             Lower percentile bound for clipping.
+         upperClip : float
+             Upper percentile bound for clipping.
+
+         Returns
+         -------
+         sig_obs_pairs : numpy array
+             A shuffled two-column array of (signal, observation) pixel
+             intensities.
+         """
+         lb = np.percentile(signal, lowerClip)
+         ub = np.percentile(signal, upperClip)
+         stepsize = observation[0].size
+         n_observations = observation.shape[0]
+         n_signals = signal.shape[0]
+         sig_obs_pairs = np.zeros((n_observations * stepsize, 2))
+
+         for i in range(n_observations):
+             j = i // (n_observations // n_signals)
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel()
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel()
+         sig_obs_pairs = sig_obs_pairs[
+             (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
+         ]
+         return self.fast_shuffle(sig_obs_pairs, 2)
+
+     def train(
+         self,
+         signal,
+         observation,
+         learning_rate=1e-1,
+         batchSize=250000,
+         n_epochs=2000,
+         name="GMMNoiseModel.npz",
+         lowerClip=0,
+         upperClip=100,
+     ):
+         """Learn the noise model from signal-observation pairs.
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean signal data.
+         observation : numpy array
+             Noisy observation data.
+         learning_rate : float
+             Learning rate. Default = 1e-1.
+         batchSize : int
+             Mini-batch size. Default = 250000.
+         n_epochs : int
+             Number of epochs. Default = 2000.
+         name : string
+             Model name. Default is `GMMNoiseModel.npz`. After training, the model
+             is saved at the location `path`.
+         lowerClip : int
+             Lower percentile for clipping. Default is 0.
+         upperClip : int
+             Upper percentile for clipping. Default is 100.
+         """
+         sig_obs_pairs = self.get_signal_observation_pairs(
+             signal, observation, lowerClip, upperClip
+         )
+         counter = 0
+         optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
+         for t in range(n_epochs):
+             jointLoss = 0
+             if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
+                 counter = 0
+                 sig_obs_pairs = self.fast_shuffle(sig_obs_pairs, 1)
+
+             batch_vectors = sig_obs_pairs[
+                 counter * batchSize : (counter + 1) * batchSize, :
+             ]
+             observations = batch_vectors[:, 1].astype(np.float32)
+             signals = batch_vectors[:, 0].astype(np.float32)
+             observations = (
+                 torch.from_numpy(observations.astype(np.float32))
+                 .float()
+                 .to(self.device)
+             )
+             signals = torch.from_numpy(signals).float().to(self.device)
+             p = self.likelihood(observations, signals)
+             loss = torch.mean(-torch.log(p))
+             jointLoss = jointLoss + loss
+
+             if t % 100 == 0:
+                 print(t, jointLoss.item())
+
+             if t % (int(n_epochs * 0.5)) == 0:
+                 trained_weight = self.weight.cpu().detach().numpy()
+                 min_signal = self.min_signal.cpu().detach().numpy()
+                 max_signal = self.max_signal.cpu().detach().numpy()
+                 np.savez(
+                     self.path + name,
+                     trained_weight=trained_weight,
+                     min_signal=min_signal,
+                     max_signal=max_signal,
+                     min_sigma=self.min_sigma,
+                 )
+
+             optimizer.zero_grad()
+             jointLoss.backward()
+             optimizer.step()
+             counter += 1
+
+         logger.info(f"The trained parameters ({name}) are saved at: {self.path}")
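
For orientation, the two noise models introduced in this file can be exercised roughly as below. This is a minimal sketch, not part of the diff: the import path mirrors the new file location, and the shapes, component counts, CPU device and output path are illustrative assumptions.

import numpy as np
import torch

from careamics.losses.noise_models import (
    GaussianMixtureNoiseModel,
    HistogramNoiseModel,
)

# Paired clean/noisy stacks: two noisy frames per clean frame (illustrative).
signal = (np.random.rand(5, 32, 32) * 100).astype(np.float32)
signal_rep = np.repeat(signal, 2, axis=0)
observation = signal_rep + np.random.randn(10, 32, 32).astype(np.float32) * 5.0

# Histogram model: instantiate() bins the (signal, observation) pairs.
histogram = HistogramNoiseModel().instantiate(
    bins=64,
    min_value=float(observation.min()),
    max_value=float(observation.max()),
    observation=observation,
    signal=signal,
)
assert histogram.shape == (3, 64, 64)

# GMM model: three components with a linear signal dependence (n_coeff=2).
gmm = GaussianMixtureNoiseModel(
    min_signal=float(signal.min()),
    max_signal=float(signal.max()),
    weight=None,  # randomly initialized from min/max signal
    n_gaussian=3,
    n_coeff=2,
    device=torch.device("cpu"),
    path="./",  # train() writes <path><name> via np.savez
    min_sigma=50,
)
gmm.train(signal, observation, batchSize=5000, n_epochs=200)

# Per-pixel likelihood of the observations under the learned model.
p = gmm.likelihood(
    torch.from_numpy(observation.ravel()),
    torch.from_numpy(signal_rep.ravel()),
)

Note that HistogramNoiseModel.likelihood relies on attributes (fullHist, bins, minv, maxv) that its __init__ does not yet set, so only instantiate() is exercised here.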
careamics/model_io/__init__.py
@@ -0,0 +1,8 @@
+ """Model I/O utilities."""
+
+
+ __all__ = ["load_pretrained", "export_to_bmz"]
+
+
+ from .bmz_io import export_to_bmz
+ from .model_io_utils import load_pretrained
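
These two re-exports define the public surface of the new model_io package. Their implementations live in bmz_io.py and model_io_utils.py (listed above but not expanded in this diff), so only the import is sketched here:

from careamics.model_io import export_to_bmz, load_pretrained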
careamics/model_io/bioimage/__init__.py
@@ -0,0 +1,11 @@
+ """Bioimage Model Zoo format functions."""
+
+ __all__ = [
+     "create_model_description",
+     "extract_model_path",
+     "get_unzip_path",
+     "create_env_text",
+ ]
+
+ from .bioimage_utils import create_env_text, get_unzip_path
+ from .model_description import create_model_description, extract_model_path
careamics/model_io/bioimage/_readme_factory.py
@@ -0,0 +1,120 @@
+ """Functions used to create a README.md file for BMZ export."""
+ from pathlib import Path
+ from typing import Optional
+
+ import yaml
+
+ from careamics.config import Configuration
+ from careamics.utils import cwd, get_careamics_home
+
+
+ def _yaml_block(yaml_str: str) -> str:
+     """Return a markdown code block with a YAML string.
+
+     Parameters
+     ----------
+     yaml_str : str
+         YAML string.
+
+     Returns
+     -------
+     str
+         Markdown code block with the YAML string.
+     """
+     return f"```yaml\n{yaml_str}\n```"
+
+
+ def readme_factory(
+     config: Configuration,
+     careamics_version: str,
+     data_description: Optional[str] = None,
+ ) -> Path:
+     """Create a README file for the model.
+
+     `data_description` can be used to add more information about the content of
+     the data the model was trained on.
+
+     Parameters
+     ----------
+     config : Configuration
+         CAREamics configuration.
+     careamics_version : str
+         CAREamics version.
+     data_description : Optional[str], optional
+         Description of the data, by default None.
+
+     Returns
+     -------
+     Path
+         Path to the README file.
+     """
+     algorithm = config.algorithm_config
+     training = config.training_config
+     data = config.data_config
+
+     # create file
+     # TODO use tempfile as in the bmz_io module
+     with cwd(get_careamics_home()):
+         readme = Path("README.md")
+         readme.touch()
+
+         # algorithm pretty name
+         algorithm_flavour = config.get_algorithm_flavour()
+         algorithm_pretty_name = algorithm_flavour + " - CAREamics"
+
+         description = [f"# {algorithm_pretty_name}\n\n"]
+
+         # algorithm description
+         description.append("Algorithm description:\n\n")
+         description.append(config.get_algorithm_description())
+         description.append("\n\n")
+
+         # algorithm details
+         description.append(
+             f"{algorithm_flavour} was trained using CAREamics (version "
+             f"{careamics_version}) with the following algorithm "
+             f"parameters:\n\n"
+         )
+         description.append(
+             _yaml_block(yaml.dump(algorithm.model_dump(exclude_none=True)))
+         )
+         description.append("\n\n")
+
+         # data description
+         description.append("## Data description\n\n")
+         if data_description is not None:
+             description.append(data_description)
+             description.append("\n\n")
+
+         description.append("The data was processed using the following parameters:\n\n")
+
+         description.append(_yaml_block(yaml.dump(data.model_dump(exclude_none=True))))
+         description.append("\n\n")
+
+         # training description
+         description.append("## Training description\n\n")
+
+         description.append("The model was trained using the following parameters:\n\n")
+
+         description.append(
+             _yaml_block(yaml.dump(training.model_dump(exclude_none=True)))
+         )
+         description.append("\n\n")
+
+         # references
+         reference = config.get_algorithm_references()
+         if reference != "":
+             description.append("## References\n\n")
+             description.append(reference)
+             description.append("\n\n")
+
+         # links
+         description.append(
+             "## Links\n\n"
+             "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
+             "- [CAREamics documentation](https://careamics.github.io/latest/)\n"
+         )
+
+         readme.write_text("".join(description))
+
+         return readme
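
A minimal sketch of driving readme_factory, assuming an already-validated Configuration is at hand (building one via the new configuration_factory module is outside this hunk); the version string and data description are illustrative:

from careamics.config import Configuration
from careamics.model_io.bioimage._readme_factory import readme_factory

config: Configuration = ...  # assumed: a validated CAREamics configuration

readme = readme_factory(
    config,
    careamics_version="0.1.0rc4",
    data_description="2D fluorescence microscopy patches.",
)
print(readme)  # README.md, created under the CAREamics home directory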