careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from ..utils.logging import get_logger
|
|
7
|
+
|
|
8
|
+
# Module-level logger; used by the noise-model training code to report where
# trained parameters are saved.
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# TODO here "Model" clashes a bit with the naming convention of the Pydantic Models
|
|
12
|
+
class NoiseModel(ABC):
    """Abstract interface shared by all noise models.

    Concrete subclasses implement :meth:`instantiate`, which prepares the model so
    it can be used, and :meth:`likelihood`, which scores observations against
    hypothesised clean signals.
    """

    @abstractmethod
    def instantiate(self):
        """Build a ready-to-use noise model.

        Concrete subclasses may require additional arguments.
        """

    @abstractmethod
    def likelihood(self, observations, signals):
        """Return the likelihood of ``observations`` given ``signals``."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class HistogramNoiseModel(NoiseModel):
    """Noise model backed by a 2D signal/observation histogram.

    Call :meth:`instantiate` to build the histogram from paired noisy and clean
    images. Besides returning the histogram, it stores the lookup table and
    binning parameters (``fullHist``, ``bins``, ``minv``, ``maxv``) that
    :meth:`likelihood` relies on.

    Parameters
    ----------
    **kwargs
        Optional keyword arguments. ``device`` selects the torch device on which
        the histogram lookup table is placed (``None`` keeps torch's default
        device). Other keys are currently ignored.
    """

    def __init__(self, **kwargs):
        self.device = kwargs.get("device")

    def instantiate(self, bins, min_value, max_value, observation, signal):
        """Create a 2D histogram from 'observation' and 'signal'.

        Parameters
        ----------
        bins: int
            The number of bins in each of the two dimensions. The total number of
            bins is ``bins ** 2``.
        min_value: float
            The lower bound of the lowest bin.
        max_value: float
            The upper bound of the highest bin.
        observation: np.array
            A stack of noisy images. Their number has to be divisible by the
            number of images in signal. N subsequent images in observation belong
            to one image in the signal.
        signal: np.array
            A stack of clean images.

        Returns
        -------
        histogram: numpy array
            A 3D array of shape ``(3, bins, bins)``:
            'histogram[0,...]' holds the normalized 2D counts, indexed as
            ``[signal_bin, observation_bin]``. Each non-empty row sums to 1,
            describing p(x_i|s_i).
            'histogram[1,...]' holds the lower boundaries of each bin in y.
            'histogram[2,...]' holds the upper boundaries of each bin in y.
            The values for x can be obtained by transposing 'histogram[1,...]'
            and 'histogram[2,...]'.

        Raises
        ------
        ValueError
            If either stack is empty, or if the number of observations is not a
            multiple of the number of signals.
        """
        n_observations = observation.shape[0]
        n_signals = signal.shape[0]
        if n_observations == 0 or n_signals == 0:
            raise ValueError(
                "`observation` and `signal` must each contain at least one image."
            )
        if n_observations % n_signals != 0:
            raise ValueError(
                f"The number of observations ({n_observations}) must be divisible "
                f"by the number of signals ({n_signals})."
            )
        img_factor = n_observations // n_signals
        value_range = [min_value, max_value]

        histogram = np.zeros((3, bins, bins))
        edges = None
        for i in range(n_observations):
            observation_i = np.ravel(observation[i])
            signal_i = np.ravel(signal[i // img_factor])

            counts, all_edges = np.histogramdd(
                (signal_i, observation_i), bins=bins, range=[value_range, value_range]
            )
            # Adding a constant for numerical stability
            histogram[0] += counts + 1e-30
            # `np.histogramdd` returns one edge array per dimension; both axes use
            # the same range, so the signal-axis edges describe both of them.
            edges = all_edges[0]

        # Normalize each non-empty row so it describes p(x | s). Rows that only
        # contain the stability constant are excluded from normalization.
        row_sums = histogram[0].sum(axis=1)
        non_empty = row_sums > 1e-20
        histogram[0, non_empty] /= row_sums[non_empty, np.newaxis]

        # The lower/upper boundaries of each bin in y are stored in dimensions 1
        # and 2 (constant along the last axis); the ones for x are just transposed.
        histogram[1] = edges[:-1, np.newaxis]
        histogram[2] = edges[1:, np.newaxis]

        # State consumed by `likelihood` and the index helpers.
        self.fullHist = torch.as_tensor(
            histogram[0], dtype=torch.float32, device=self.device
        )
        self.bins = bins
        self.minv = float(min_value)
        self.maxv = float(max_value)

        return histogram

    def likelihood(self, observed, signal):
        """Calculate the likelihood using a histogram based noise model.

        For every pixel in a tensor, calculate p(x_i|s_i). To ensure
        differentiability in the direction of s_i, we linearly interpolate in this
        direction. Requires :meth:`instantiate` to have been called first.

        Parameters
        ----------
        observed: torch.Tensor
            Tensor holding your observed intensities x_i.
        signal: torch.Tensor
            Tensor holding hypotheses for the clean signal at every pixel s_i^k.

        Returns
        -------
        torch.Tensor
            Tensor containing the observation likelihoods according to the noise
            model.
        """
        observed_long = self.get_index_observed_float(observed).floor().long()
        signal_float = self.get_index_signal_float(signal)
        signal_long = signal_float.floor().long()
        fact = signal_float - signal_long.float()

        # Upper interpolation neighbour, kept inside the valid index range
        # [0, bins - 1] of the lookup table.
        signal_upper = torch.clamp(signal_long + 1, 0, int(self.bins) - 1)

        # Finally we are looking up the values and interpolate
        return self.fullHist[signal_long, observed_long] * (
            1.0 - fact
        ) + self.fullHist[signal_upper, observed_long] * fact

    def get_index_observed_float(self, x: torch.Tensor):
        """Map observed intensities to fractional histogram bin coordinates.

        ``[minv, maxv]`` is scaled onto ``[0, bins]`` and the result is clamped to
        ``[0, bins - 1]``, so flooring it always yields a valid bin index; values
        in, or beyond, the last bin map to index ``bins - 1``.

        Parameters
        ----------
        x : torch.Tensor
            Observed intensities.

        Returns
        -------
        torch.Tensor
            Fractional bin coordinates of `x`.
        """
        return torch.clamp(
            self.bins * (x - self.minv) / (self.maxv - self.minv),
            min=0.0,
            max=self.bins - 1,
        )

    def get_index_signal_float(self, x):
        """Map signal intensities to fractional histogram bin coordinates.

        ``[minv, maxv]`` is scaled onto ``[0, bins]`` and the result is clamped to
        ``[0, bins - 1 - 1e-3]``. This keeps the floored index below the last bin
        so `likelihood` can interpolate towards an upper neighbour (which it also
        clamps explicitly).

        Parameters
        ----------
        x : torch.Tensor
            Signal intensities.

        Returns
        -------
        torch.Tensor
            Fractional bin coordinates of `x`.
        """
        return torch.clamp(
            self.bins * (x - self.minv) / (self.maxv - self.minv),
            min=0.0,
            max=self.bins - 1 - 1e-3,
        )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# TODO refactor this into Pydantic model
|
|
173
|
+
class GaussianMixtureNoiseModel(NoiseModel):
    """Describes a noise model parameterized as a mixture of gaussians.

    If you would like to initialize a new object from scratch, then set `params` = None
    and specify the other parameters as keyword arguments. If you are instead loading
    a model, use `params` (optionally together with `device` and `path`).

    Parameters
    ----------
    **kwargs: keyworded, variable-length argument dictionary.
        Arguments include:
        min_signal : float
            Minimum signal intensity expected in the image.
        max_signal : float
            Maximum signal intensity expected in the image.
        weight : array
            A [3*n_gaussian, n_coeff] sized array containing the values of the weights
            describing the noise model.
            Each gaussian contributes three parameters (mean, standard deviation and
            weight), hence the number of rows in `weight` are 3*n_gaussian. Rows are
            grouped as: mean coefficients, then log-coefficients of the variance
            polynomial, then mixture-weight logits.
            If `weight = None`, the weight array is initialized randomly using the
            `min_signal` and `max_signal` parameters.
        n_gaussian: int
            Number of gaussians (only used when `weight` is None).
        n_coeff: int
            Number of coefficients to describe the functional relationship between
            gaussian parameters and the signal (only used when `weight` is None).
            2 implies a linear relationship, 3 implies a quadratic relationship and so on.
        device: device
            Torch device on which the model tensors are placed.
        path: str
            Prefix (typically a directory ending with a path separator) prepended to
            the file name when `train` saves parameters. When omitted, parameters are
            saved relative to the current working directory.
        min_sigma: float
            Lower bound applied to the variance predicted for each gaussian before its
            square root is taken, i.e. standard deviations are at least
            ``sqrt(min_sigma)``.
        params: dictionary
            Use `params` if one wishes to load a model with trained weights, e.g. the
            content of the `.npz` file written by `train` (keys `trained_weight`,
            `min_signal`, `max_signal` and `min_sigma`).
            While initializing a new object of the class `GaussianMixtureNoiseModel` from
            scratch, set this to `None`.
    """

    def __init__(self, **kwargs):
        self.device = kwargs.get("device")
        self.path = kwargs.get("path")
        params = kwargs.get("params")

        if params is None:
            weight = kwargs.get("weight")
            n_gaussian = kwargs.get("n_gaussian")
            n_coeff = kwargs.get("n_coeff")
            min_signal = kwargs.get("min_signal")
            max_signal = kwargs.get("max_signal")
            self.min_sigma = kwargs.get("min_sigma")

            if weight is None:
                weight = np.random.randn(n_gaussian * 3, n_coeff)
                # Start the linear variance coefficients at the signal range.
                weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
            weight = torch.from_numpy(weight.astype(np.float32)).float().to(self.device)
            weight.requires_grad = True
            self.weight = weight

            self.min_signal = torch.Tensor([min_signal]).to(self.device)
            self.max_signal = torch.Tensor([max_signal]).to(self.device)
        else:
            self.weight = torch.Tensor(params["trained_weight"]).to(self.device)
            # `np.asarray` accepts both 0-d arrays (as loaded from `.npz`) and plain
            # Python numbers.
            self.min_sigma = np.asarray(params["min_sigma"]).item()
            self.min_signal = self._as_signal_bound(params["min_signal"])
            self.max_signal = self._as_signal_bound(params["max_signal"])

        self.n_gaussian = self.weight.shape[0] // 3
        self.n_coeff = self.weight.shape[1]
        self.tol = torch.Tensor([1e-10]).to(self.device)

    def _as_signal_bound(self, value):
        """Convert a scalar or array signal bound into a float32 tensor of shape (1,)."""
        array = np.asarray(value, dtype=np.float32).reshape(-1)
        return torch.from_numpy(array).to(self.device)

    def fast_shuffle(self, series, num):
        """Shuffle the rows of an array several times.

        Parameters
        ----------
        series : numpy array
            2D array whose rows are shuffled.
        num : int
            Number of successive random row permutations to apply.

        Returns
        -------
        numpy array
            Row-shuffled copy of `series` (uses NumPy's global random state).
        """
        length = series.shape[0]
        for _i in range(num):
            series = series[np.random.permutation(length), :]
        return series

    def polynomial_regressor(self, weightParams, signals):
        """Combines weight_parameters and signals to perform regression.

        Evaluates ``sum_i weightParams[i] * s_norm ** i`` where ``s_norm`` is the
        signal rescaled with `min_signal` and `max_signal`.

        Parameters
        ----------
        weightParams : torch.cuda.FloatTensor
            Corresponds to specific rows of the `self.weight`.
        signals : torch.cuda.FloatTensor
            Signals.

        Returns
        -------
        value : torch.cuda.FloatTensor
            Corresponds to either of mean, standard deviation or weight, evaluated at
            `signals`.
        """
        normalized = (signals - self.min_signal) / (self.max_signal - self.min_signal)
        value = 0
        for i in range(weightParams.shape[0]):
            value += weightParams[i] * normalized**i
        return value

    def normal_density(self, x, m_=0.0, std_=None):
        """Evaluates the normal probability density.

        Parameters
        ----------
        x: torch.cuda.FloatTensor
            Observations.
        m_: torch.cuda.FloatTensor
            Mean.
        std_: torch.cuda.FloatTensor
            Standard-deviation.

        Returns
        -------
        tmp: torch.cuda.FloatTensor
            Normal probability density of `x` given `m_` and `std_`.
        """
        tmp = -((x - m_) ** 2)
        tmp = tmp / (2.0 * std_ * std_)
        tmp = torch.exp(tmp)
        tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
        return tmp

    def likelihood(self, observations, signals):
        """Evaluates the likelihood of observations.

        Given the signals and the corresponding gaussian parameters evaluates the
        likelihood of observations.

        Parameters
        ----------
        observations : torch.cuda.FloatTensor
            Noisy observations.
        signals : torch.cuda.FloatTensor
            Underlying signals.

        Returns
        -------
        value : torch.cuda.FloatTensor
            Likelihood of observations given the signals and the GMM noise model,
            offset by `self.tol` to keep it strictly positive.
        """
        gaussian_parameters = self.get_gaussian_parameters(signals)
        n = self.n_gaussian
        p = 0
        for gaussian in range(n):
            p += (
                self.normal_density(
                    observations,
                    gaussian_parameters[gaussian],
                    gaussian_parameters[n + gaussian],
                )
                * gaussian_parameters[2 * n + gaussian]
            )
        return p + self.tol

    def get_gaussian_parameters(self, signals):
        """Returns the noise model for given signals.

        The mixture weights are normalized to sum to one, and the means are shifted
        so that the weighted mean of the mixture equals `signals`.

        Parameters
        ----------
        signals : torch.cuda.FloatTensor
            Underlying signals.

        Returns
        -------
        noiseModel: list of torch.cuda.FloatTensor
            Concatenation of the lists `mu`, `sigma` and `alpha` (one entry per
            gaussian each), evaluated at `signals`.
        """
        kernels = self.weight.shape[0] // 3

        mu = [
            self.polynomial_regressor(self.weight[num, :], signals)
            for num in range(kernels)
        ]

        sigma = []
        for num in range(kernels):
            variance = self.polynomial_regressor(
                torch.exp(self.weight[kernels + num, :]), signals
            )
            # `min_sigma` bounds the variance, before the square root.
            variance = torch.clamp(variance, min=self.min_sigma)
            sigma.append(torch.sqrt(variance))

        alpha = [
            torch.exp(
                self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
                + self.tol
            )
            for num in range(kernels)
        ]

        sum_alpha = sum(alpha)
        alpha = [a / sum_alpha for a in alpha]

        sum_means = sum(a * m for a, m in zip(alpha, mu))
        mu = [m - sum_means + signals for m in mu]

        return mu + sigma + alpha

    def get_signal_observation_pairs(self, signal, observation, lowerClip, upperClip):
        """Returns the Signal-Observation pixel intensities as a two-column array.

        Parameters
        ----------
        signal : numpy array
            Clean Signal Data.
        observation: numpy array
            Noisy observation Data. N subsequent observations belong to one signal
            image.
        lowerClip: float
            Lower percentile bound for clipping.
        upperClip: float
            Upper percentile bound for clipping.

        Returns
        -------
        numpy array
            Shuffled array of shape (n_pairs, 2) with signal intensities in column 0
            and observation intensities in column 1. Only pairs whose signal lies
            strictly between the `lowerClip` and `upperClip` percentiles of `signal`
            are kept.
        """
        lb = np.percentile(signal, lowerClip)
        ub = np.percentile(signal, upperClip)
        stepsize = observation[0].size
        n_observations = observation.shape[0]
        n_signals = signal.shape[0]
        sig_obs_pairs = np.zeros((n_observations * stepsize, 2))

        for i in range(n_observations):
            j = i // (n_observations // n_signals)
            sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel()
            sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel()
        sig_obs_pairs = sig_obs_pairs[
            (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
        ]
        return self.fast_shuffle(sig_obs_pairs, 2)

    def train(
        self,
        signal,
        observation,
        learning_rate=1e-1,
        batchSize=250000,
        n_epochs=2000,
        name="GMMNoiseModel.npz",
        lowerClip=0,
        upperClip=100,
    ):
        """Training to learn the noise model from signal - observation pairs.

        Parameters are checkpointed with `np.savez` to ``path + name`` at the first
        epoch and then every ``max(1, int(n_epochs * 0.5))`` epochs.

        Parameters
        ----------
        signal: numpy array
            Clean Signal Data.
        observation: numpy array
            Noisy Observation Data.
        learning_rate: float
            Learning rate. Default = 1e-1.
        batchSize: int
            Mini-batch size. Default = 250000.
        n_epochs: int
            Number of epochs. Default = 2000.
        name: string
            File name. Default is `GMMNoiseModel.npz`. The trained model is saved at
            the location `path` + `name`.
        lowerClip : int
            Lower percentile for clipping. Default is 0.
        upperClip : int
            Upper percentile for clipping. Default is 100.
        """
        sig_obs_pairs = self.get_signal_observation_pairs(
            signal, observation, lowerClip, upperClip
        )
        # A missing `path` means "save relative to the current working directory".
        path_prefix = "" if self.path is None else str(self.path)
        # Guard against a zero modulus when `n_epochs` < 2.
        checkpoint_every = max(1, int(n_epochs * 0.5))

        counter = 0
        optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
        for t in range(n_epochs):
            if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
                counter = 0
                sig_obs_pairs = self.fast_shuffle(sig_obs_pairs, 1)

            batch_vectors = sig_obs_pairs[
                counter * batchSize : (counter + 1) * batchSize, :
            ]
            observations = torch.from_numpy(
                batch_vectors[:, 1].astype(np.float32)
            ).to(self.device)
            signals = torch.from_numpy(batch_vectors[:, 0].astype(np.float32)).to(
                self.device
            )

            p = self.likelihood(observations, signals)
            loss = torch.mean(-torch.log(p))

            if t % 100 == 0:
                print(t, loss.item())

            if t % checkpoint_every == 0:
                trained_weight = self.weight.cpu().detach().numpy()
                min_signal = self.min_signal.cpu().detach().numpy()
                max_signal = self.max_signal.cpu().detach().numpy()
                np.savez(
                    path_prefix + name,
                    trained_weight=trained_weight,
                    min_signal=min_signal,
                    max_signal=max_signal,
                    min_sigma=self.min_sigma,
                )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            counter += 1

        logger.info(f"The trained parameters {name} is saved at location: " + path_prefix)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Bioimage Model Zoo format functions."""
|
|
2
|
+
|
|
3
|
+
# Public API of the Bioimage Model Zoo helpers; each name is re-exported from
# the submodules imported below.
__all__ = [
    "create_model_description",
    "extract_model_path",
    "get_unzip_path",
    "create_env_text",
]
|
|
9
|
+
|
|
10
|
+
from .bioimage_utils import create_env_text, get_unzip_path
|
|
11
|
+
from .model_description import create_model_description, extract_model_path
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Functions used to create a README.md file for BMZ export."""
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
|
|
7
|
+
from careamics.config import Configuration
|
|
8
|
+
from careamics.utils import cwd, get_careamics_home
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _yaml_block(yaml_str: str) -> str:
|
|
12
|
+
"""Return a markdown code block with a yaml string.
|
|
13
|
+
|
|
14
|
+
Parameters
|
|
15
|
+
----------
|
|
16
|
+
yaml_str : str
|
|
17
|
+
YAML string.
|
|
18
|
+
|
|
19
|
+
Returns
|
|
20
|
+
-------
|
|
21
|
+
str
|
|
22
|
+
Markdown code block with the YAML string.
|
|
23
|
+
"""
|
|
24
|
+
return f"```yaml\n{yaml_str}\n```"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def readme_factory(
    config: Configuration,
    careamics_version: str,
    data_description: Optional[str] = None,
) -> Path:
    """Create a README file for the model.

    `data_description` can be used to add more information about the content of the
    data the model was trained on.

    The file is written as `README.md` inside the directory returned by
    `get_careamics_home()`, replacing any previous content.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    careamics_version : str
        CAREamics version.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Returns
    -------
    Path
        Absolute path to the README file.
    """
    algorithm = config.algorithm_config
    training = config.training_config
    data = config.data_config

    # algorithm pretty name
    algorithm_flavour = config.get_algorithm_flavour()
    algorithm_pretty_name = algorithm_flavour + " - CAREamics"

    description = [f"# {algorithm_pretty_name}\n\n"]

    # algorithm description
    description.append("Algorithm description:\n\n")
    description.append(config.get_algorithm_description())
    description.append("\n\n")

    # algorithm details
    description.append(
        f"{algorithm_flavour} was trained using CAREamics (version "
        f"{careamics_version}) with the following algorithm "
        f"parameters:\n\n"
    )
    description.append(
        _yaml_block(yaml.dump(algorithm.model_dump(exclude_none=True)))
    )
    description.append("\n\n")

    # data description
    description.append("## Data description\n\n")
    if data_description is not None:
        description.append(data_description)
        description.append("\n\n")

    description.append("The data was processed using the following parameters:\n\n")

    description.append(_yaml_block(yaml.dump(data.model_dump(exclude_none=True))))
    description.append("\n\n")

    # training description
    description.append("## Training description\n\n")

    description.append("The model was trained using the following parameters:\n\n")

    description.append(
        _yaml_block(yaml.dump(training.model_dump(exclude_none=True)))
    )
    description.append("\n\n")

    # references
    reference = config.get_algorithm_references()
    if reference != "":
        description.append("## References\n\n")
        description.append(reference)
        description.append("\n\n")

    # links
    description.append(
        "## Links\n\n"
        "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
        "- [CAREamics documentation](https://careamics.github.io/latest/)\n"
    )

    # create file
    # TODO use tempfile as in the bmz_io module
    with cwd(get_careamics_home()):
        # Resolve the path while inside the CAREamics home directory: a relative
        # `Path("README.md")` would be re-interpreted against the caller's working
        # directory once the context manager has exited.
        readme = Path("README.md").absolute()
        readme.write_text("".join(description))

    return readme
|