careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics has been flagged as potentially problematic; see the release's registry page for more details.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,50 +1,111 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Script containing modules for
|
|
2
|
+
Script containing modules for defining different likelihood functions (as nn.Module).
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
5
7
|
import math
|
|
6
|
-
from typing import
|
|
8
|
+
from typing import Literal, Union, TYPE_CHECKING, Any, Optional
|
|
7
9
|
|
|
8
|
-
import numpy as np
|
|
9
10
|
import torch
|
|
10
11
|
from torch import nn
|
|
11
12
|
|
|
13
|
+
from careamics.config.likelihood_model import (
|
|
14
|
+
GaussianLikelihoodConfig,
|
|
15
|
+
NMLikelihoodConfig,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from careamics.models.lvae.noise_models import (
|
|
20
|
+
GaussianMixtureNoiseModel,
|
|
21
|
+
MultiChannelNoiseModel,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def likelihood_factory(
|
|
28
|
+
config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
|
|
29
|
+
):
|
|
30
|
+
"""
|
|
31
|
+
Factory function for creating likelihood modules.
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig]
|
|
36
|
+
The configuration object for the likelihood module.
|
|
37
|
+
|
|
38
|
+
Returns
|
|
39
|
+
-------
|
|
40
|
+
nn.Module
|
|
41
|
+
The likelihood module.
|
|
42
|
+
"""
|
|
43
|
+
if config is None:
|
|
44
|
+
return None
|
|
12
45
|
|
|
46
|
+
if isinstance(config, GaussianLikelihoodConfig):
|
|
47
|
+
return GaussianLikelihood(
|
|
48
|
+
predict_logvar=config.predict_logvar,
|
|
49
|
+
logvar_lowerbound=config.logvar_lowerbound,
|
|
50
|
+
)
|
|
51
|
+
elif isinstance(config, NMLikelihoodConfig):
|
|
52
|
+
return NoiseModelLikelihood(
|
|
53
|
+
data_mean=config.data_mean,
|
|
54
|
+
data_std=config.data_std,
|
|
55
|
+
noiseModel=config.noise_model,
|
|
56
|
+
)
|
|
57
|
+
else:
|
|
58
|
+
raise ValueError(f"Invalid likelihood model type: {config.model_type}")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
|
|
13
62
|
class LikelihoodModule(nn.Module):
|
|
14
63
|
"""
|
|
15
64
|
The base class for all likelihood modules.
|
|
16
65
|
It defines the fundamental structure and methods for specialized likelihood models.
|
|
17
66
|
"""
|
|
18
67
|
|
|
19
|
-
def distr_params(self, x):
|
|
68
|
+
def distr_params(self, x: Any) -> None:
|
|
20
69
|
return None
|
|
21
70
|
|
|
22
|
-
def set_params_to_same_device_as(self, correct_device_tensor):
|
|
71
|
+
def set_params_to_same_device_as(self, correct_device_tensor: Any) -> None:
|
|
23
72
|
pass
|
|
24
73
|
|
|
25
74
|
@staticmethod
|
|
26
|
-
def logvar(params):
|
|
75
|
+
def logvar(params: Any) -> None:
|
|
27
76
|
return None
|
|
28
77
|
|
|
29
78
|
@staticmethod
|
|
30
|
-
def mean(params):
|
|
79
|
+
def mean(params: Any) -> None:
|
|
31
80
|
return None
|
|
32
81
|
|
|
33
82
|
@staticmethod
|
|
34
|
-
def mode(params):
|
|
83
|
+
def mode(params: Any) -> None:
|
|
35
84
|
return None
|
|
36
85
|
|
|
37
86
|
@staticmethod
|
|
38
|
-
def sample(params):
|
|
87
|
+
def sample(params: Any) -> None:
|
|
39
88
|
return None
|
|
40
89
|
|
|
41
|
-
def log_likelihood(self, x, params):
|
|
90
|
+
def log_likelihood(self, x: Any, params: Any) -> None:
|
|
42
91
|
return None
|
|
43
92
|
|
|
44
|
-
def
|
|
45
|
-
self,
|
|
46
|
-
) ->
|
|
93
|
+
def get_mean_lv(
|
|
94
|
+
self, x: torch.Tensor
|
|
95
|
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...
|
|
47
96
|
|
|
97
|
+
def forward(
|
|
98
|
+
self, input_: torch.Tensor, x: Union[torch.Tensor, None]
|
|
99
|
+
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
|
100
|
+
"""
|
|
101
|
+
Parameters:
|
|
102
|
+
-----------
|
|
103
|
+
input_: torch.Tensor
|
|
104
|
+
The output of the top-down pass (e.g., reconstructed image in HDN,
|
|
105
|
+
or the unmixed images in 'Split' models).
|
|
106
|
+
x: Union[torch.Tensor, None]
|
|
107
|
+
The target tensor. If None, the log-likelihood is not computed.
|
|
108
|
+
"""
|
|
48
109
|
distr_params = self.distr_params(input_)
|
|
49
110
|
mean = self.mean(distr_params)
|
|
50
111
|
mode = self.mode(distr_params)
|
|
@@ -68,8 +129,7 @@ class LikelihoodModule(nn.Module):
|
|
|
68
129
|
|
|
69
130
|
|
|
70
131
|
class GaussianLikelihood(LikelihoodModule):
|
|
71
|
-
r"""
|
|
72
|
-
A specialize `LikelihoodModule` for Gaussian likelihood.
|
|
132
|
+
r"""A specialized `LikelihoodModule` for Gaussian likelihood.
|
|
73
133
|
|
|
74
134
|
Specifically, in the LVAE model, the likelihood is defined as:
|
|
75
135
|
p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
|
|
@@ -77,50 +137,32 @@ class GaussianLikelihood(LikelihoodModule):
|
|
|
77
137
|
|
|
78
138
|
def __init__(
|
|
79
139
|
self,
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
predict_logvar: Literal[None, "pixelwise", "global", "channelwise"] = None,
|
|
83
|
-
logvar_lowerbound: float = None,
|
|
84
|
-
conv2d_bias: bool = True,
|
|
140
|
+
predict_logvar: Union[Literal["pixelwise"], None] = None,
|
|
141
|
+
logvar_lowerbound: Union[float, None] = None,
|
|
85
142
|
):
|
|
86
|
-
"""
|
|
87
|
-
Constructor.
|
|
143
|
+
"""Constructor.
|
|
88
144
|
|
|
89
145
|
Parameters
|
|
90
146
|
----------
|
|
91
|
-
predict_logvar: Literal[
|
|
92
|
-
If
|
|
93
|
-
|
|
94
|
-
- if `pixelwise`, log-variance is computed for each pixel.
|
|
95
|
-
- if `global`, log-variance is computed as the mean of all pixel-wise entries.
|
|
96
|
-
- if `channelwise`, log-variance is computed as the average over the channels.
|
|
97
|
-
Default is `None`.
|
|
147
|
+
predict_logvar: Union[Literal["pixelwise"], None], optional
|
|
148
|
+
If `pixelwise`, log-variance is computed for each pixel, else log-variance
|
|
149
|
+
is not computed. Default is `None`.
|
|
98
150
|
logvar_lowerbound: float, optional
|
|
99
151
|
The lowerbound value for log-variance. Default is `None`.
|
|
100
|
-
conv2d_bias: bool, optional
|
|
101
|
-
Whether to use bias term in convolutions. Default is `True`.
|
|
102
152
|
"""
|
|
103
153
|
super().__init__()
|
|
104
154
|
|
|
105
|
-
# If True, then we also predict pixelwise logvar.
|
|
106
155
|
self.predict_logvar = predict_logvar
|
|
107
156
|
self.logvar_lowerbound = logvar_lowerbound
|
|
108
|
-
self.
|
|
109
|
-
assert self.predict_logvar in [None, "global", "pixelwise", "channelwise"]
|
|
110
|
-
|
|
111
|
-
# logvar_ch_needed = self.predict_logvar is not None
|
|
112
|
-
# self.parameter_net = nn.Conv2d(ch_in,
|
|
113
|
-
# color_channels * (1 + logvar_ch_needed),
|
|
114
|
-
# kernel_size=3,
|
|
115
|
-
# padding=1,
|
|
116
|
-
# bias=self.conv2d_bias)
|
|
117
|
-
self.parameter_net = nn.Identity()
|
|
157
|
+
assert self.predict_logvar in [None, "pixelwise"]
|
|
118
158
|
|
|
119
159
|
print(
|
|
120
160
|
f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
|
|
121
161
|
)
|
|
122
162
|
|
|
123
|
-
def get_mean_lv(
|
|
163
|
+
def get_mean_lv(
|
|
164
|
+
self, x: torch.Tensor
|
|
165
|
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
124
166
|
"""
|
|
125
167
|
Given the output of the top-down pass, compute the mean and log-variance of the
|
|
126
168
|
Gaussian distribution defining the likelihood.
|
|
@@ -128,50 +170,42 @@ class GaussianLikelihood(LikelihoodModule):
|
|
|
128
170
|
Parameters
|
|
129
171
|
----------
|
|
130
172
|
x: torch.Tensor
|
|
131
|
-
The input tensor to the likelihood module, i.e., the output of the top-down
|
|
173
|
+
The input tensor to the likelihood module, i.e., the output of the top-down
|
|
174
|
+
pass.
|
|
175
|
+
|
|
176
|
+
Returns
|
|
177
|
+
-------
|
|
178
|
+
tuple of (torch.tensor, optional torch.tensor)
|
|
179
|
+
The first element of the tuple is the mean, the second element is the
|
|
180
|
+
log-variance. If the attribute `predict_logvar` is `None` then the second
|
|
181
|
+
element will be `None`.
|
|
132
182
|
"""
|
|
133
|
-
# Feed the output of the top-down pass to a parameter network
|
|
134
|
-
# This network can be either a Conv2d or Identity module
|
|
135
|
-
x = self.parameter_net(x)
|
|
136
183
|
|
|
137
|
-
if
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
new_shape = (*mean.shape[:1], *([1] * len(mean.shape[1:])))
|
|
151
|
-
else:
|
|
152
|
-
raise ValueError(
|
|
153
|
-
f"Invalid value for self.predict_logvar:{self.predict_logvar}"
|
|
154
|
-
)
|
|
155
|
-
|
|
156
|
-
lv = torch.mean(lv.reshape(N, -1), dim=1)
|
|
157
|
-
lv = lv.reshape(new_shape)
|
|
158
|
-
|
|
159
|
-
# Optionally, clip log-var to a lower bound
|
|
160
|
-
if self.logvar_lowerbound is not None:
|
|
161
|
-
lv = torch.clip(lv, min=self.logvar_lowerbound)
|
|
162
|
-
else:
|
|
163
|
-
mean = x
|
|
164
|
-
lv = None
|
|
184
|
+
# if LadderVAE.predict_logvar is None, dim 1 of `x`` has no. of target channels
|
|
185
|
+
if self.predict_logvar is None:
|
|
186
|
+
return x, None
|
|
187
|
+
|
|
188
|
+
# Get pixel-wise mean and logvar
|
|
189
|
+
# if LadderVAE.predict_logvar is not None,
|
|
190
|
+
# dim 1 has double no. of target channels
|
|
191
|
+
mean, lv = x.chunk(2, dim=1)
|
|
192
|
+
|
|
193
|
+
# Optionally, clip log-var to a lower bound
|
|
194
|
+
if self.logvar_lowerbound is not None:
|
|
195
|
+
lv = torch.clip(lv, min=self.logvar_lowerbound)
|
|
196
|
+
|
|
165
197
|
return mean, lv
|
|
166
198
|
|
|
167
|
-
def distr_params(self, x: torch.Tensor) ->
|
|
199
|
+
def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
168
200
|
"""
|
|
169
201
|
Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.
|
|
170
202
|
|
|
171
203
|
Parameters
|
|
172
204
|
----------
|
|
173
205
|
x: torch.Tensor
|
|
174
|
-
The input tensor to the likelihood module, i.e., the output
|
|
206
|
+
The input tensor to the likelihood module, i.e., the output
|
|
207
|
+
the LVAE 'output_layer'. Shape is: (B, 2 * C, [Z], Y, X) in case
|
|
208
|
+
`predict_logvar` is not None, or (B, C, [Z], Y, X) otherwise.
|
|
175
209
|
"""
|
|
176
210
|
mean, lv = self.get_mean_lv(x)
|
|
177
211
|
params = {
|
|
@@ -181,24 +215,41 @@ class GaussianLikelihood(LikelihoodModule):
|
|
|
181
215
|
return params
|
|
182
216
|
|
|
183
217
|
@staticmethod
|
|
184
|
-
def mean(params):
|
|
218
|
+
def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
185
219
|
return params["mean"]
|
|
186
220
|
|
|
187
221
|
@staticmethod
|
|
188
|
-
def mode(params):
|
|
222
|
+
def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
189
223
|
return params["mean"]
|
|
190
224
|
|
|
191
225
|
@staticmethod
|
|
192
|
-
def sample(params):
|
|
226
|
+
def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
193
227
|
# p = Normal(params['mean'], (params['logvar'] / 2).exp())
|
|
194
228
|
# return p.rsample()
|
|
195
229
|
return params["mean"]
|
|
196
230
|
|
|
197
231
|
@staticmethod
|
|
198
|
-
def logvar(params):
|
|
232
|
+
def logvar(params: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
199
233
|
return params["logvar"]
|
|
200
234
|
|
|
201
|
-
def log_likelihood(
|
|
235
|
+
def log_likelihood(
|
|
236
|
+
self, x: torch.Tensor, params: dict[str, Union[torch.Tensor, None]]
|
|
237
|
+
):
|
|
238
|
+
"""Compute Gaussian log-likelihood
|
|
239
|
+
|
|
240
|
+
Parameters
|
|
241
|
+
----------
|
|
242
|
+
x: torch.Tensor
|
|
243
|
+
The target tensor. Shape is (B, C, [Z], Y, X).
|
|
244
|
+
params: dict[str, Union[torch.Tensor, None]]
|
|
245
|
+
The tensors obtained by chunking the output of the top-down pass,
|
|
246
|
+
here used as parameters of the Gaussian distribution.
|
|
247
|
+
|
|
248
|
+
Returns
|
|
249
|
+
-------
|
|
250
|
+
torch.Tensor
|
|
251
|
+
The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
|
|
252
|
+
"""
|
|
202
253
|
if self.predict_logvar is not None:
|
|
203
254
|
logprob = log_normal(x, params["mean"], params["logvar"])
|
|
204
255
|
else:
|
|
@@ -236,39 +287,39 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
236
287
|
|
|
237
288
|
def __init__(
|
|
238
289
|
self,
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
data_std: Union[Dict[str, torch.Tensor], torch.Tensor],
|
|
243
|
-
noiseModel: nn.Module,
|
|
290
|
+
data_mean: torch.Tensor,
|
|
291
|
+
data_std: torch.Tensor,
|
|
292
|
+
noiseModel: NoiseModel, # TODO: check the type -> couldn't manage due to circular imports...
|
|
244
293
|
):
|
|
294
|
+
"""Constructor.
|
|
295
|
+
|
|
296
|
+
Parameters
|
|
297
|
+
----------
|
|
298
|
+
data_mean: torch.Tensor
|
|
299
|
+
The mean of the data, used to unnormalize data for noise model evaluation.
|
|
300
|
+
data_std: torch.Tensor
|
|
301
|
+
The standard deviation of the data, used to unnormalize data for noise
|
|
302
|
+
model evaluation.
|
|
303
|
+
noiseModel: NoiseModel
|
|
304
|
+
The noise model instance used to compute the likelihood.
|
|
305
|
+
"""
|
|
245
306
|
super().__init__()
|
|
246
|
-
self.parameter_net = (
|
|
247
|
-
nn.Identity()
|
|
248
|
-
) # nn.Conv2d(ch_in, color_channels, kernel_size=3, padding=1)
|
|
249
307
|
self.data_mean = data_mean
|
|
250
308
|
self.data_std = data_std
|
|
251
309
|
self.noiseModel = noiseModel
|
|
252
310
|
|
|
253
|
-
def set_params_to_same_device_as(
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
)
|
|
263
|
-
self.data_std[key] = self.data_std[key].to(correct_device_tensor.device)
|
|
264
|
-
|
|
265
|
-
def get_mean_lv(self, x):
|
|
266
|
-
return self.parameter_net(x), None
|
|
267
|
-
|
|
268
|
-
def distr_params(self, x):
|
|
269
|
-
mean, lv = self.get_mean_lv(x)
|
|
270
|
-
# mean, lv = x.chunk(2, dim=1)
|
|
311
|
+
def set_params_to_same_device_as(
|
|
312
|
+
self, correct_device_tensor: torch.Tensor
|
|
313
|
+
) -> None: # TODO: needed?
|
|
314
|
+
if self.data_mean.device != correct_device_tensor.device:
|
|
315
|
+
self.data_mean = self.data_mean.to(correct_device_tensor.device)
|
|
316
|
+
self.data_std = self.data_std.to(correct_device_tensor.device)
|
|
317
|
+
|
|
318
|
+
def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
|
|
319
|
+
return x, None
|
|
271
320
|
|
|
321
|
+
def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
322
|
+
mean, lv = self.get_mean_lv(x)
|
|
272
323
|
params = {
|
|
273
324
|
"mean": mean,
|
|
274
325
|
"logvar": lv,
|
|
@@ -276,37 +327,38 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
276
327
|
return params
|
|
277
328
|
|
|
278
329
|
@staticmethod
|
|
279
|
-
def mean(params):
|
|
330
|
+
def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
280
331
|
return params["mean"]
|
|
281
332
|
|
|
282
333
|
@staticmethod
|
|
283
|
-
def mode(params):
|
|
334
|
+
def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
284
335
|
return params["mean"]
|
|
285
336
|
|
|
286
337
|
@staticmethod
|
|
287
|
-
def sample(params):
|
|
288
|
-
# p = Normal(params['mean'], (params['logvar'] / 2).exp())
|
|
289
|
-
# return p.rsample()
|
|
338
|
+
def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
290
339
|
return params["mean"]
|
|
291
340
|
|
|
292
|
-
def log_likelihood(self, x: torch.Tensor, params:
|
|
293
|
-
"""
|
|
294
|
-
|
|
341
|
+
def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
|
|
342
|
+
"""Compute the log-likelihood given the parameters `params` obtained
|
|
343
|
+
from the reconstruction tensor and the target tensor `x`.
|
|
344
|
+
|
|
345
|
+
Parameters
|
|
346
|
+
----------
|
|
347
|
+
x: torch.Tensor
|
|
348
|
+
The target tensor. Shape is (B, C, [Z], Y, X).
|
|
349
|
+
params: dict[str, Union[torch.Tensor, None]]
|
|
350
|
+
The tensors obtained from output of the top-down pass.
|
|
351
|
+
Here, "mean" correspond to the whole output, while logvar is `None`.
|
|
352
|
+
|
|
353
|
+
Returns
|
|
354
|
+
-------
|
|
355
|
+
torch.Tensor
|
|
356
|
+
The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
|
|
295
357
|
"""
|
|
296
|
-
predicted_s_denormalized =
|
|
297
|
-
|
|
298
|
-
)
|
|
299
|
-
x_denormalized = x * self.data_std["target"] + self.data_mean["target"]
|
|
300
|
-
# predicted_s_cloned = predicted_s_denormalized
|
|
301
|
-
# predicted_s_reduced = predicted_s_cloned.permute(1, 0, 2, 3)
|
|
302
|
-
|
|
303
|
-
# x_cloned = x_denormalized
|
|
304
|
-
# x_cloned = x_cloned.permute(1, 0, 2, 3)
|
|
305
|
-
# x_reduced = x_cloned[0, ...]
|
|
306
|
-
# import pdb;pdb.set_trace()
|
|
358
|
+
predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
|
|
359
|
+
x_denormalized = x * self.data_std + self.data_mean
|
|
307
360
|
likelihoods = self.noiseModel.likelihood(
|
|
308
361
|
x_denormalized, predicted_s_denormalized
|
|
309
362
|
)
|
|
310
|
-
# likelihoods = self.noiseModel.likelihood(x, params['mean'])
|
|
311
363
|
logprob = torch.log(likelihoods)
|
|
312
364
|
return logprob
|