careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registries; it is provided for informational purposes only.
Potentially problematic release: this version of careamics might be problematic.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/likelihoods.py (added, +312 lines):

```diff
@@ -0,0 +1,312 @@
+"""
+Script containing modules for defining different likelihood functions (as nn.Module).
+"""
+
+import math
+from typing import Dict, Literal, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+
+class LikelihoodModule(nn.Module):
+    """
+    The base class for all likelihood modules.
+    It defines the fundamental structure and methods for specialized likelihood models.
+    """
+
+    def distr_params(self, x):
+        return None
+
+    def set_params_to_same_device_as(self, correct_device_tensor):
+        pass
+
+    @staticmethod
+    def logvar(params):
+        return None
+
+    @staticmethod
+    def mean(params):
+        return None
+
+    @staticmethod
+    def mode(params):
+        return None
+
+    @staticmethod
+    def sample(params):
+        return None
+
+    def log_likelihood(self, x, params):
+        return None
+
+    def forward(
+        self, input_: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+
+        distr_params = self.distr_params(input_)
+        mean = self.mean(distr_params)
+        mode = self.mode(distr_params)
+        sample = self.sample(distr_params)
+        logvar = self.logvar(distr_params)
+
+        if x is None:
+            ll = None
+        else:
+            ll = self.log_likelihood(x, distr_params)
+
+        dct = {
+            "mean": mean,
+            "mode": mode,
+            "sample": sample,
+            "params": distr_params,
+            "logvar": logvar,
+        }
+
+        return ll, dct
+
+
+class GaussianLikelihood(LikelihoodModule):
+    r"""
+    A specialized `LikelihoodModule` for Gaussian likelihood.
+
+    Specifically, in the LVAE model, the likelihood is defined as:
+    p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
+    """
+
+    def __init__(
+        self,
+        ch_in: int,
+        color_channels: int,
+        predict_logvar: Literal[None, "pixelwise", "global", "channelwise"] = None,
+        logvar_lowerbound: float = None,
+        conv2d_bias: bool = True,
+    ):
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        predict_logvar: Literal[None, 'global', 'pixelwise', 'channelwise'], optional
+            If not `None`, it expresses how to compute the log-variance.
+            Namely:
+            - if `pixelwise`, log-variance is computed for each pixel.
+            - if `global`, log-variance is computed as the mean of all pixel-wise entries.
+            - if `channelwise`, log-variance is computed as the average over the channels.
+            Default is `None`.
+        logvar_lowerbound: float, optional
+            The lowerbound value for log-variance. Default is `None`.
+        conv2d_bias: bool, optional
+            Whether to use bias term in convolutions. Default is `True`.
+        """
+        super().__init__()
+
+        # If True, then we also predict pixelwise logvar.
+        self.predict_logvar = predict_logvar
+        self.logvar_lowerbound = logvar_lowerbound
+        self.conv2d_bias = conv2d_bias
+        assert self.predict_logvar in [None, "global", "pixelwise", "channelwise"]
+
+        # logvar_ch_needed = self.predict_logvar is not None
+        # self.parameter_net = nn.Conv2d(ch_in,
+        #                                color_channels * (1 + logvar_ch_needed),
+        #                                kernel_size=3,
+        #                                padding=1,
+        #                                bias=self.conv2d_bias)
+        self.parameter_net = nn.Identity()
+
+        print(
+            f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
+        )
+
+    def get_mean_lv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Given the output of the top-down pass, compute the mean and log-variance of the
+        Gaussian distribution defining the likelihood.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The input tensor to the likelihood module, i.e., the output of the top-down pass.
+        """
+        # Feed the output of the top-down pass to a parameter network
+        # This network can be either a Conv2d or Identity module
+        x = self.parameter_net(x)
+
+        if self.predict_logvar is not None:
+            # Get pixel-wise mean and logvar
+            mean, lv = x.chunk(2, dim=1)
+
+            # Optionally, compute the global or channel-wise logvar
+            if self.predict_logvar in ["channelwise", "global"]:
+                if self.predict_logvar == "channelwise":
+                    # logvar should be of the following shape (batch, num_channels, ). Other dims would be singletons.
+                    N = np.prod(lv.shape[:2])
+                    new_shape = (*mean.shape[:2], *([1] * len(mean.shape[2:])))
+                elif self.predict_logvar == "global":
+                    # logvar should be of the following shape (batch, ). Other dims would be singletons.
+                    N = lv.shape[0]
+                    new_shape = (*mean.shape[:1], *([1] * len(mean.shape[1:])))
+                else:
+                    raise ValueError(
+                        f"Invalid value for self.predict_logvar:{self.predict_logvar}"
+                    )
+
+                lv = torch.mean(lv.reshape(N, -1), dim=1)
+                lv = lv.reshape(new_shape)
+
+            # Optionally, clip log-var to a lower bound
+            if self.logvar_lowerbound is not None:
+                lv = torch.clip(lv, min=self.logvar_lowerbound)
+        else:
+            mean = x
+            lv = None
+        return mean, lv
+
+    def distr_params(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """
+        Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The input tensor to the likelihood module, i.e., the output of the top-down pass.
+        """
+        mean, lv = self.get_mean_lv(x)
+        params = {
+            "mean": mean,
+            "logvar": lv,
+        }
+        return params
+
+    @staticmethod
+    def mean(params):
+        return params["mean"]
+
+    @staticmethod
+    def mode(params):
+        return params["mean"]
+
+    @staticmethod
+    def sample(params):
+        # p = Normal(params['mean'], (params['logvar'] / 2).exp())
+        # return p.rsample()
+        return params["mean"]
+
+    @staticmethod
+    def logvar(params):
+        return params["logvar"]
+
+    def log_likelihood(self, x, params):
+        if self.predict_logvar is not None:
+            logprob = log_normal(x, params["mean"], params["logvar"])
+        else:
+            logprob = -0.5 * (params["mean"] - x) ** 2
+        return logprob
+
+
+def log_normal(
+    x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute the log-probability at `x` of a Gaussian distribution
+    with parameters `(mean, exp(logvar))`.
+
+    NOTE: In the case of LVAE, the log-likelihood formula becomes:
+    \\mathbb{E}_{z_1\\sim{q_\\phi}}[\\log{p_\\theta(x|z_1)}] = -\\frac{1}{2}(\\mathbb{E}_{z_1\\sim{q_\\phi}}[\\log{2\\pi\\sigma_{p,0}^2(z_1)}] + \\mathbb{E}_{z_1\\sim{q_\\phi}}[\\frac{(x-\\mu_{p,0}(z_1))^2}{\\sigma_{p,0}^2(z_1)}])
+
+    Parameters
+    ----------
+    x: torch.Tensor
+        The ground-truth tensor. Shape is (batch, channels, dim1, dim2).
+    mean: torch.Tensor
+        The inferred mean of distribution. Shape is (batch, channels, dim1, dim2).
+    logvar: torch.Tensor
+        The inferred log-variance of distribution. Shape has to be either scalar or broadcastable.
+    """
+    var = torch.exp(logvar)
+    log_prob = -0.5 * (
+        ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
+    )
+    return log_prob
+
+
+class NoiseModelLikelihood(LikelihoodModule):
+
+    def __init__(
+        self,
+        ch_in: int,
+        color_channels: int,
+        data_mean: Union[Dict[str, torch.Tensor], torch.Tensor],
+        data_std: Union[Dict[str, torch.Tensor], torch.Tensor],
+        noiseModel: nn.Module,
+    ):
+        super().__init__()
+        self.parameter_net = (
+            nn.Identity()
+        )  # nn.Conv2d(ch_in, color_channels, kernel_size=3, padding=1)
+        self.data_mean = data_mean
+        self.data_std = data_std
+        self.noiseModel = noiseModel
+
+    def set_params_to_same_device_as(self, correct_device_tensor):
+        if isinstance(self.data_mean, torch.Tensor):
+            if self.data_mean.device != correct_device_tensor.device:
+                self.data_mean = self.data_mean.to(correct_device_tensor.device)
+                self.data_std = self.data_std.to(correct_device_tensor.device)
+        elif isinstance(self.data_mean, dict):
+            for key in self.data_mean.keys():
+                self.data_mean[key] = self.data_mean[key].to(
+                    correct_device_tensor.device
+                )
+                self.data_std[key] = self.data_std[key].to(correct_device_tensor.device)
+
+    def get_mean_lv(self, x):
+        return self.parameter_net(x), None
+
+    def distr_params(self, x):
+        mean, lv = self.get_mean_lv(x)
+        # mean, lv = x.chunk(2, dim=1)
+
+        params = {
+            "mean": mean,
+            "logvar": lv,
+        }
+        return params
+
+    @staticmethod
+    def mean(params):
+        return params["mean"]
+
+    @staticmethod
+    def mode(params):
+        return params["mean"]
+
+    @staticmethod
+    def sample(params):
+        # p = Normal(params['mean'], (params['logvar'] / 2).exp())
+        # return p.rsample()
+        return params["mean"]
+
+    def log_likelihood(self, x: torch.Tensor, params: Dict[str, torch.Tensor]):
+        """
+        Compute the log-likelihood given the parameters `params` obtained from the reconstruction tensor and the target tensor `x`.
+        """
+        predicted_s_denormalized = (
+            params["mean"] * self.data_std["target"] + self.data_mean["target"]
+        )
+        x_denormalized = x * self.data_std["target"] + self.data_mean["target"]
+        # predicted_s_cloned = predicted_s_denormalized
+        # predicted_s_reduced = predicted_s_cloned.permute(1, 0, 2, 3)
+
+        # x_cloned = x_denormalized
+        # x_cloned = x_cloned.permute(1, 0, 2, 3)
+        # x_reduced = x_cloned[0, ...]
+        # import pdb;pdb.set_trace()
+        likelihoods = self.noiseModel.likelihood(
+            x_denormalized, predicted_s_denormalized
+        )
+        # likelihoods = self.noiseModel.likelihood(x, params['mean'])
+        logprob = torch.log(likelihoods)
+        return logprob
```