careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Script containing modules for defining different likelihood functions (as nn.Module).
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
from typing import Literal, Union, TYPE_CHECKING, Any, Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from careamics.config.likelihood_model import (
|
|
14
|
+
GaussianLikelihoodConfig,
|
|
15
|
+
NMLikelihoodConfig,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from careamics.models.lvae.noise_models import (
|
|
20
|
+
GaussianMixtureNoiseModel,
|
|
21
|
+
MultiChannelNoiseModel,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def likelihood_factory(
    config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
) -> Optional[LikelihoodModule]:
    """
    Factory function for creating likelihood modules.

    Parameters
    ----------
    config : Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
        The configuration object for the likelihood module. If `None`, no
        likelihood module is created.

    Returns
    -------
    Optional[LikelihoodModule]
        A `GaussianLikelihood` for a `GaussianLikelihoodConfig`, a
        `NoiseModelLikelihood` for a `NMLikelihoodConfig`, or `None` when
        `config` is `None`.

    Raises
    ------
    ValueError
        If `config` is not one of the supported configuration types.
    """
    if config is None:
        return None

    if isinstance(config, GaussianLikelihoodConfig):
        return GaussianLikelihood(
            predict_logvar=config.predict_logvar,
            logvar_lowerbound=config.logvar_lowerbound,
        )
    if isinstance(config, NMLikelihoodConfig):
        return NoiseModelLikelihood(
            data_mean=config.data_mean,
            data_std=config.data_std,
            noiseModel=config.noise_model,
        )

    # An unsupported object may not define `model_type`; fall back to its type
    # name so callers always get the documented ValueError (not AttributeError).
    model_type = getattr(config, "model_type", type(config).__name__)
    raise ValueError(f"Invalid likelihood model type: {model_type}")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
|
|
62
|
+
class LikelihoodModule(nn.Module):
    """
    Common base for every likelihood module.

    Subclasses override the hooks below (`distr_params`, `mean`, `mode`,
    `sample`, `logvar`, `log_likelihood`, `get_mean_lv`). In this base class
    each hook yields `None`; `forward` chains the hooks together.
    """

    def distr_params(self, x: Any) -> None:
        """Hook returning the distribution parameters for `x` (base: `None`)."""
        return None

    def set_params_to_same_device_as(self, correct_device_tensor: Any) -> None:
        """Hook moving internal state to the tensor's device (base: no-op)."""
        return None

    @staticmethod
    def logvar(params: Any) -> None:
        """Hook extracting the log-variance from `params` (base: `None`)."""
        return None

    @staticmethod
    def mean(params: Any) -> None:
        """Hook extracting the mean from `params` (base: `None`)."""
        return None

    @staticmethod
    def mode(params: Any) -> None:
        """Hook extracting the mode from `params` (base: `None`)."""
        return None

    @staticmethod
    def sample(params: Any) -> None:
        """Hook producing a sample from `params` (base: `None`)."""
        return None

    def log_likelihood(self, x: Any, params: Any) -> None:
        """Hook computing the log-likelihood of `x` (base: `None`)."""
        return None

    def get_mean_lv(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Hook splitting `x` into mean and log-variance (base body is empty)."""

    def forward(
        self, input_: torch.Tensor, x: Union[torch.Tensor, None]
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
        Evaluate the likelihood hooks on the output of the top-down pass.

        Parameters
        ----------
        input_ : torch.Tensor
            The output of the top-down pass (e.g., reconstructed image in HDN,
            or the unmixed images in 'Split' models).
        x : Union[torch.Tensor, None]
            The target tensor. If None, the log-likelihood is not computed.

        Returns
        -------
        tuple
            `(ll, outputs)`: `ll` is `log_likelihood(x, params)`, or `None` when
            `x` is `None`; `outputs` maps "mean", "mode", "sample", "params"
            and "logvar" to the corresponding hook results.
        """
        params = self.distr_params(input_)
        # The hooks are evaluated in the order mean -> mode -> sample -> logvar,
        # and only afterwards is the log-likelihood computed.
        outputs = {
            "mean": self.mean(params),
            "mode": self.mode(params),
            "sample": self.sample(params),
            "params": params,
            "logvar": self.logvar(params),
        }
        log_lik = None if x is None else self.log_likelihood(x, params)
        return log_lik, outputs
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class GaussianLikelihood(LikelihoodModule):
    r"""A specialized `LikelihoodModule` for Gaussian likelihood.

    Specifically, in the LVAE model, the likelihood is defined as:
    p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
    """

    def __init__(
        self,
        predict_logvar: Union[Literal["pixelwise"], None] = None,
        logvar_lowerbound: Union[float, None] = None,
    ):
        """Constructor.

        Parameters
        ----------
        predict_logvar : Union[Literal["pixelwise"], None], optional
            If `pixelwise`, log-variance is computed for each pixel, else
            log-variance is not computed. Default is `None`.
        logvar_lowerbound : float, optional
            The lower bound used to clip the predicted log-variance. Default is
            `None` (no clipping).

        Raises
        ------
        ValueError
            If `predict_logvar` is neither `None` nor `"pixelwise"`.
        """
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if predict_logvar not in [None, "pixelwise"]:
            raise ValueError(
                f"Invalid `predict_logvar`: {predict_logvar!r}. "
                "Expected None or 'pixelwise'."
            )

        self.predict_logvar = predict_logvar
        self.logvar_lowerbound = logvar_lowerbound

        print(
            f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
        )

    def get_mean_lv(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Given the output of the top-down pass, compute the mean and log-variance of the
        Gaussian distribution defining the likelihood.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to the likelihood module, i.e., the output of the
            top-down pass.

        Returns
        -------
        tuple of (torch.Tensor, optional torch.Tensor)
            The first element of the tuple is the mean, the second element is the
            log-variance. If the attribute `predict_logvar` is `None` then the
            second element will be `None`.
        """
        # Without log-variance prediction, dim 1 of `x` holds only the target
        # channels, so the whole tensor is the mean.
        if self.predict_logvar is None:
            return x, None

        # With pixel-wise log-variance, dim 1 holds twice the number of target
        # channels: the first half is the mean, the second the log-variance.
        mean, lv = x.chunk(2, dim=1)

        # Optionally, clip the log-variance to a lower bound.
        if self.logvar_lowerbound is not None:
            lv = torch.clip(lv, min=self.logvar_lowerbound)

        return mean, lv

    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to the likelihood module, i.e., the output of
            the LVAE 'output_layer'. Shape is: (B, 2 * C, [Z], Y, X) in case
            `predict_logvar` is not None, or (B, C, [Z], Y, X) otherwise.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with keys "mean" and "logvar" ("logvar" is `None` when
            `predict_logvar` is `None`).
        """
        mean, lv = self.get_mean_lv(x)
        params = {
            "mean": mean,
            "logvar": lv,
        }
        return params

    @staticmethod
    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean of the distribution."""
        return params["mean"]

    @staticmethod
    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mode of the distribution (equal to the mean for a Gaussian)."""
        return params["mean"]

    @staticmethod
    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return a "sample" of the distribution.

        NOTE: this returns the mean deterministically; no random draw from
        N(mean, exp(logvar / 2)) is performed.
        """
        return params["mean"]

    @staticmethod
    def logvar(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the log-variance of the distribution (may be `None`)."""
        return params["logvar"]

    def log_likelihood(
        self, x: torch.Tensor, params: dict[str, Union[torch.Tensor, None]]
    ):
        """Compute Gaussian log-likelihood.

        Parameters
        ----------
        x : torch.Tensor
            The target tensor. Shape is (B, C, [Z], Y, X).
        params : dict[str, Union[torch.Tensor, None]]
            The tensors obtained by chunking the output of the top-down pass,
            here used as parameters of the Gaussian distribution.

        Returns
        -------
        torch.Tensor
            The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
        """
        if self.predict_logvar is not None:
            logprob = log_normal(x, params["mean"], params["logvar"])
        else:
            # No variance predicted: use the unnormalized Gaussian log-density
            # with unit variance (a scaled negative squared error).
            logprob = -0.5 * (params["mean"] - x) ** 2
        return logprob
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def log_normal(
    x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor
) -> torch.Tensor:
    r"""
    Compute the log-probability at `x` of a Gaussian distribution
    with parameters `(mean, exp(logvar))`.

    NOTE: In the case of LVAE, the log-likelihood formula becomes:
    \mathbb{E}_{z_1\sim{q_\phi}}[\log{p_\theta(x|z_1)}]=-\frac{1}{2}(\mathbb{E}_{z_1\sim{q_\phi}}[\log{2\pi\sigma_{p,0}^2(z_1)}] +\mathbb{E}_{z_1\sim{q_\phi}}[\frac{(x-\mu_{p,0}(z_1))^2}{\sigma_{p,0}^2(z_1)}])

    Parameters
    ----------
    x : torch.Tensor
        The ground-truth tensor. Shape is (batch, channels, dim1, dim2).
    mean : torch.Tensor
        The inferred mean of distribution. Shape is (batch, channels, dim1, dim2).
    logvar : torch.Tensor
        The inferred log-variance of distribution. Shape has to be either scalar
        or broadcastable.

    Returns
    -------
    torch.Tensor
        Element-wise log-probability, with the broadcast shape of the inputs.
    """
    var = torch.exp(logvar)
    # Use a Python float for the normalization constant: it avoids allocating a
    # float32 CPU tensor on every call and keeps full precision for float64 inputs.
    log_two_pi = math.log(2 * math.pi)
    return -0.5 * (((x - mean) ** 2) / var + logvar + log_two_pi)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class NoiseModelLikelihood(LikelihoodModule):
    """A `LikelihoodModule` whose likelihood is evaluated by a noise model.

    The network output is used directly as the (normalized) signal estimate.
    Both this estimate and the target are denormalized with `data_mean` and
    `data_std`, passed to `noiseModel.likelihood`, and the result is
    log-transformed.
    """

    def __init__(
        self,
        data_mean: torch.Tensor,
        data_std: torch.Tensor,
        noiseModel: NoiseModel,  # TODO: check the type -> couldn't manage due to circular imports...
    ):
        """Constructor.

        Parameters
        ----------
        data_mean : torch.Tensor
            The mean of the data, used to unnormalize data for noise model evaluation.
        data_std : torch.Tensor
            The standard deviation of the data, used to unnormalize data for noise
            model evaluation.
        noiseModel : NoiseModel
            The noise model instance used to compute the likelihood.
        """
        super().__init__()
        self.data_mean = data_mean
        self.data_std = data_std
        self.noiseModel = noiseModel

    def set_params_to_same_device_as(
        self, correct_device_tensor: torch.Tensor
    ) -> None:  # TODO: needed?
        """Move `data_mean` and `data_std` to the device of `correct_device_tensor`.

        Plain (non-Parameter) tensors assigned as attributes are not registered
        as buffers, so `nn.Module.to` does not move them; this method does it
        explicitly.

        Parameters
        ----------
        correct_device_tensor : torch.Tensor
            Tensor whose device is the target device.
        """
        device = correct_device_tensor.device
        # Check each tensor independently so `data_std` is moved even when
        # `data_mean` already lives on the target device.
        if self.data_mean.device != device:
            self.data_mean = self.data_mean.to(device)
        if self.data_std.device != device:
            self.data_std = self.data_std.to(device)

    def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
        """Return `x` unchanged as the mean; no log-variance is predicted."""
        return x, None

    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Wrap the network output into a parameter dict.

        Returns
        -------
        dict[str, torch.Tensor]
            "mean" is the whole output `x`; "logvar" is always `None`.
        """
        mean, lv = self.get_mean_lv(x)
        params = {
            "mean": mean,
            "logvar": lv,
        }
        return params

    @staticmethod
    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean (the network output)."""
        return params["mean"]

    @staticmethod
    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mode (identical to the mean here)."""
        return params["mean"]

    @staticmethod
    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean; no random sample is drawn."""
        return params["mean"]

    def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
        """Compute the log-likelihood given the parameters `params` obtained
        from the reconstruction tensor and the target tensor `x`.

        Parameters
        ----------
        x : torch.Tensor
            The target tensor. Shape is (B, C, [Z], Y, X).
        params : dict[str, Union[torch.Tensor, None]]
            The tensors obtained from output of the top-down pass.
            Here, "mean" corresponds to the whole output, while "logvar" is `None`.

        Returns
        -------
        torch.Tensor
            The log-likelihood tensor. Shape is (B, C, [Z], Y, X). Entries where
            the noise model returns 0 become `-inf` (natural log of zero).
        """
        # The noise model works in the original (unnormalized) intensity space.
        predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
        x_denormalized = x * self.data_std + self.data_mean
        likelihoods = self.noiseModel.likelihood(
            x_denormalized, predicted_s_denormalized
        )
        logprob = torch.log(likelihoods)
        return logprob
|