careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/likelihoods.py

@@ -1,50 +1,112 @@
 """
-Script containing modules for
+Script containing modules for defining different likelihood functions (as nn.Module).
 """
 
+from __future__ import annotations
+
 import math
-from typing import
+from typing import Literal, Union, TYPE_CHECKING, Any, Optional
 
 import numpy as np
 import torch
 from torch import nn
 
+from careamics.config.likelihood_model import (
+    GaussianLikelihoodConfig,
+    NMLikelihoodConfig,
+)
+
+if TYPE_CHECKING:
+    from careamics.models.lvae.noise_models import (
+        GaussianMixtureNoiseModel,
+        MultiChannelNoiseModel,
+    )
+
+    NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+def likelihood_factory(
+    config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
+):
+    """
+    Factory function for creating likelihood modules.
+
+    Parameters
+    ----------
+    config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig]
+        The configuration object for the likelihood module.
 
+    Returns
+    -------
+    nn.Module
+        The likelihood module.
+    """
+    if config is None:
+        return None
+
+    if isinstance(config, GaussianLikelihoodConfig):
+        return GaussianLikelihood(
+            predict_logvar=config.predict_logvar,
+            logvar_lowerbound=config.logvar_lowerbound,
+        )
+    elif isinstance(config, NMLikelihoodConfig):
+        return NoiseModelLikelihood(
+            data_mean=config.data_mean,
+            data_std=config.data_std,
+            noiseModel=config.noise_model,
+        )
+    else:
+        raise ValueError(f"Invalid likelihood model type: {config.model_type}")
+
+
+# TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
 class LikelihoodModule(nn.Module):
     """
     The base class for all likelihood modules.
     It defines the fundamental structure and methods for specialized likelihood models.
     """
 
-    def distr_params(self, x):
+    def distr_params(self, x: Any) -> None:
         return None
 
-    def set_params_to_same_device_as(self, correct_device_tensor):
+    def set_params_to_same_device_as(self, correct_device_tensor: Any) -> None:
         pass
 
     @staticmethod
-    def logvar(params):
+    def logvar(params: Any) -> None:
         return None
 
     @staticmethod
-    def mean(params):
+    def mean(params: Any) -> None:
         return None
 
     @staticmethod
-    def mode(params):
+    def mode(params: Any) -> None:
         return None
 
     @staticmethod
-    def sample(params):
+    def sample(params: Any) -> None:
         return None
 
-    def log_likelihood(self, x, params):
+    def log_likelihood(self, x: Any, params: Any) -> None:
         return None
 
-    def
-        self,
-    ) ->
+    def get_mean_lv(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...
 
+    def forward(
+        self, input_: torch.Tensor, x: Union[torch.Tensor, None]
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        """
+        Parameters:
+        -----------
+        input_: torch.Tensor
+            The output of the top-down pass (e.g., reconstructed image in HDN,
+            or the unmixed images in 'Split' models).
+        x: Union[torch.Tensor, None]
+            The target tensor. If None, the log-likelihood is not computed.
+        """
         distr_params = self.distr_params(input_)
         mean = self.mean(distr_params)
         mode = self.mode(distr_params)
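
Note: a minimal usage sketch of the `likelihood_factory` added in the hunk above. The keyword arguments accepted by `GaussianLikelihoodConfig` are assumed from the attribute accesses (`config.predict_logvar`, `config.logvar_lowerbound`) visible in the factory; the snippet is illustrative only and not part of the package diff.

    # Sketch: build a likelihood module from its config (assumed field names).
    from careamics.config.likelihood_model import GaussianLikelihoodConfig
    from careamics.models.lvae.likelihoods import likelihood_factory

    config = GaussianLikelihoodConfig(predict_logvar="pixelwise", logvar_lowerbound=-5.0)
    likelihood = likelihood_factory(config)  # GaussianLikelihood instance
    assert likelihood_factory(None) is None  # the factory passes None through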
@@ -68,8 +130,7 @@ class LikelihoodModule(nn.Module):
 
 
 class GaussianLikelihood(LikelihoodModule):
-    r"""
-    A specialize `LikelihoodModule` for Gaussian likelihood.
+    r"""A specialized `LikelihoodModule` for Gaussian likelihood.
 
     Specifically, in the LVAE model, the likelihood is defined as:
         p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
@@ -77,50 +138,32 @@ class GaussianLikelihood(LikelihoodModule):
 
     def __init__(
         self,
-
-
-        predict_logvar: Literal[None, "pixelwise", "global", "channelwise"] = None,
-        logvar_lowerbound: float = None,
-        conv2d_bias: bool = True,
+        predict_logvar: Union[Literal["pixelwise"], None] = None,
+        logvar_lowerbound: Union[float, None] = None,
     ):
-        """
-        Constructor.
+        """Constructor.
 
         Parameters
         ----------
-        predict_logvar: Literal[
-            If
-
-            - if `pixelwise`, log-variance is computed for each pixel.
-            - if `global`, log-variance is computed as the mean of all pixel-wise entries.
-            - if `channelwise`, log-variance is computed as the average over the channels.
-            Default is `None`.
+        predict_logvar: Union[Literal["pixelwise"], None], optional
+            If `pixelwise`, log-variance is computed for each pixel, else log-variance
+            is not computed. Default is `None`.
         logvar_lowerbound: float, optional
             The lowerbound value for log-variance. Default is `None`.
-        conv2d_bias: bool, optional
-            Whether to use bias term in convolutions. Default is `True`.
         """
         super().__init__()
 
-        # If True, then we also predict pixelwise logvar.
         self.predict_logvar = predict_logvar
        self.logvar_lowerbound = logvar_lowerbound
-        self.
-        assert self.predict_logvar in [None, "global", "pixelwise", "channelwise"]
-
-        # logvar_ch_needed = self.predict_logvar is not None
-        # self.parameter_net = nn.Conv2d(ch_in,
-        # color_channels * (1 + logvar_ch_needed),
-        # kernel_size=3,
-        # padding=1,
-        # bias=self.conv2d_bias)
-        self.parameter_net = nn.Identity()
+        assert self.predict_logvar in [None, "pixelwise"]
 
         print(
             f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
         )
 
-    def get_mean_lv(
+    def get_mean_lv(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         Given the output of the top-down pass, compute the mean and log-variance of the
         Gaussian distribution defining the likelihood.
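
Note: with the constructor narrowed in the hunk above, only per-pixel log-variance prediction (or none at all) remains; a short illustrative sketch (values are arbitrary):

    # Sketch: valid constructions under the new signature.
    gaussian_plain = GaussianLikelihood()  # no log-variance prediction
    gaussian_pixelwise = GaussianLikelihood(predict_logvar="pixelwise", logvar_lowerbound=-5.0)
    # predict_logvar="global" or "channelwise" would now fail the assert.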
@@ -128,50 +171,42 @@ class GaussianLikelihood(LikelihoodModule):
 
         Parameters
         ----------
         x: torch.Tensor
-            The input tensor to the likelihood module, i.e., the output of the top-down
+            The input tensor to the likelihood module, i.e., the output of the top-down
+            pass.
+
+        Returns
+        -------
+        tuple of (torch.tensor, optional torch.tensor)
+            The first element of the tuple is the mean, the second element is the
+            log-variance. If the attribute `predict_logvar` is `None` then the second
+            element will be `None`.
         """
-        # Feed the output of the top-down pass to a parameter network
-        # This network can be either a Conv2d or Identity module
-        x = self.parameter_net(x)
 
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-                    new_shape = (*mean.shape[:1], *([1] * len(mean.shape[1:])))
-                else:
-                    raise ValueError(
-                        f"Invalid value for self.predict_logvar:{self.predict_logvar}"
-                    )
-
-                lv = torch.mean(lv.reshape(N, -1), dim=1)
-                lv = lv.reshape(new_shape)
-
-            # Optionally, clip log-var to a lower bound
-            if self.logvar_lowerbound is not None:
-                lv = torch.clip(lv, min=self.logvar_lowerbound)
-        else:
-            mean = x
-            lv = None
+        # if LadderVAE.predict_logvar is None, dim 1 of `x`` has no. of target channels
+        if self.predict_logvar is None:
+            return x, None
+
+        # Get pixel-wise mean and logvar
+        # if LadderVAE.predict_logvar is not None,
+        # dim 1 has double no. of target channels
+        mean, lv = x.chunk(2, dim=1)
+
+        # Optionally, clip log-var to a lower bound
+        if self.logvar_lowerbound is not None:
+            lv = torch.clip(lv, min=self.logvar_lowerbound)
+
         return mean, lv
 
-    def distr_params(self, x: torch.Tensor) ->
+    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         """
         Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.
 
         Parameters
         ----------
         x: torch.Tensor
-            The input tensor to the likelihood module, i.e., the output
+            The input tensor to the likelihood module, i.e., the output
+            the LVAE 'output_layer'. Shape is: (B, 2 * C, [Z], Y, X) in case
+            `predict_logvar` is not None, or (B, C, [Z], Y, X) otherwise.
         """
         mean, lv = self.get_mean_lv(x)
         params = {
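
Note: the rewritten `get_mean_lv` above splits the channel dimension of the top-down output in two when `predict_logvar` is set; a self-contained sketch of that behaviour with assumed shapes:

    import torch

    x = torch.randn(2, 4, 64, 64)          # (B, 2 * C, Y, X), here C = 2 target channels
    mean, logvar = x.chunk(2, dim=1)       # each of shape (2, 2, 64, 64)
    logvar = torch.clip(logvar, min=-5.0)  # optional lower bound, as in the module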
@@ -181,24 +216,41 @@ class GaussianLikelihood(LikelihoodModule):
         return params
 
     @staticmethod
-    def mean(params):
+    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def mode(params):
+    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def sample(params):
+    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
         # p = Normal(params['mean'], (params['logvar'] / 2).exp())
         # return p.rsample()
         return params["mean"]
 
     @staticmethod
-    def logvar(params):
+    def logvar(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["logvar"]
 
-    def log_likelihood(
+    def log_likelihood(
+        self, x: torch.Tensor, params: dict[str, Union[torch.Tensor, None]]
+    ):
+        """Compute Gaussian log-likelihood
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The target tensor. Shape is (B, C, [Z], Y, X).
+        params: dict[str, Union[torch.Tensor, None]]
+            The tensors obtained by chunking the output of the top-down pass,
+            here used as parameters of the Gaussian distribution.
+
+        Returns
+        -------
+        torch.Tensor
+            The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
+        """
         if self.predict_logvar is not None:
             logprob = log_normal(x, params["mean"], params["logvar"])
         else:
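
Note: `log_normal` used in the hunk above is the element-wise Gaussian log-density; for reference, an equivalent stand-alone computation (a sketch, not the package's own helper) is:

    import math
    import torch

    def log_normal_sketch(x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # log N(x | mean, exp(logvar)), computed element-wise
        return -0.5 * (logvar + math.log(2 * math.pi) + (x - mean) ** 2 / logvar.exp())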
@@ -236,39 +288,46 @@ class NoiseModelLikelihood(LikelihoodModule):
 
     def __init__(
         self,
-
-
-
-        data_std: Union[Dict[str, torch.Tensor], torch.Tensor],
-        noiseModel: nn.Module,
+        data_mean: Union[np.ndarray, torch.Tensor],
+        data_std: Union[np.ndarray, torch.Tensor],
+        noiseModel: NoiseModel,
     ):
+        """Constructor.
+
+        Parameters
+        ----------
+        data_mean: Union[np.ndarray, torch.Tensor]
+            The mean of the data, used to unnormalize data for noise model evaluation.
+        data_std: Union[np.ndarray, torch.Tensor]
+            The standard deviation of the data, used to unnormalize data for noise
+            model evaluation.
+        noiseModel: NoiseModel
+            The noise model instance used to compute the likelihood.
+        """
         super().__init__()
-        self.
-
-        ) # nn.Conv2d(ch_in, color_channels, kernel_size=3, padding=1)
-        self.data_mean = data_mean
-        self.data_std = data_std
+        self.data_mean = torch.Tensor(data_mean)
+        self.data_std = torch.Tensor(data_std)
         self.noiseModel = noiseModel
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        mean, lv = self.get_mean_lv(x)
-        # mean, lv = x.chunk(2, dim=1)
+    def _set_params_to_same_device_as(
+        self, correct_device_tensor: torch.Tensor
+    ) -> None:
+        """Set the parameters to the same device as the input tensor.
+
+        Parameters
+        ----------
+        correct_device_tensor: torch.Tensor
+            The tensor whose device is used to set the parameters.
+        """
+        if self.data_mean.device != correct_device_tensor.device:
+            self.data_mean = self.data_mean.to(correct_device_tensor.device)
+            self.data_std = self.data_std.to(correct_device_tensor.device)
+
+    def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
+        return x, None
 
+    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        mean, lv = self.get_mean_lv(x)
         params = {
             "mean": mean,
             "logvar": lv,
@@ -276,37 +335,39 @@ class NoiseModelLikelihood(LikelihoodModule):
         return params
 
     @staticmethod
-    def mean(params):
+    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def mode(params):
+    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def sample(params):
-        # p = Normal(params['mean'], (params['logvar'] / 2).exp())
-        # return p.rsample()
+    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
-    def log_likelihood(self, x: torch.Tensor, params:
-        """
-
+    def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
+        """Compute the log-likelihood given the parameters `params` obtained
+        from the reconstruction tensor and the target tensor `x`.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The target tensor. Shape is (B, C, [Z], Y, X).
+        params: dict[str, Union[torch.Tensor, None]]
+            The tensors obtained from output of the top-down pass.
+            Here, "mean" correspond to the whole output, while logvar is `None`.
+
+        Returns
+        -------
+        torch.Tensor
+            The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
         """
-
-
-
-        x_denormalized = x * self.data_std["target"] + self.data_mean["target"]
-        # predicted_s_cloned = predicted_s_denormalized
-        # predicted_s_reduced = predicted_s_cloned.permute(1, 0, 2, 3)
-
-        # x_cloned = x_denormalized
-        # x_cloned = x_cloned.permute(1, 0, 2, 3)
-        # x_reduced = x_cloned[0, ...]
-        # import pdb;pdb.set_trace()
+        self._set_params_to_same_device_as(x)
+        predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
+        x_denormalized = x * self.data_std + self.data_mean
         likelihoods = self.noiseModel.likelihood(
             x_denormalized, predicted_s_denormalized
         )
-        # likelihoods = self.noiseModel.likelihood(x, params['mean'])
         logprob = torch.log(likelihoods)
         return logprob
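
Note: the new `log_likelihood` above evaluates the noise model in the un-normalized intensity range; a minimal sketch of the de-normalization step with plain tensors (the noise-model call itself is only indicated in a comment):

    import torch

    data_mean, data_std = torch.tensor(100.0), torch.tensor(25.0)
    prediction = torch.randn(1, 1, 8, 8)  # normalized network output ("mean")
    target = torch.randn(1, 1, 8, 8)      # normalized target

    prediction_denorm = prediction * data_std + data_mean
    target_denorm = target * data_std + data_mean
    # likelihoods = noise_model.likelihood(target_denorm, prediction_denorm)
    # logprob = torch.log(likelihoods)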