careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of careamics might be problematic.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -1,50 +1,112 @@
  """
- Script containing modules for definining different likelihood functions (as nn.Module).
+ Script containing modules for defining different likelihood functions (as nn.Module).
  """
 
+ from __future__ import annotations
+
  import math
- from typing import Dict, Literal, Tuple, Union
+ from typing import Literal, Union, TYPE_CHECKING, Any, Optional
 
  import numpy as np
  import torch
  from torch import nn
 
+ from careamics.config.likelihood_model import (
+     GaussianLikelihoodConfig,
+     NMLikelihoodConfig,
+ )
+
+ if TYPE_CHECKING:
+     from careamics.models.lvae.noise_models import (
+         GaussianMixtureNoiseModel,
+         MultiChannelNoiseModel,
+     )
+
+     NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ def likelihood_factory(
+     config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
+ ):
+     """
+     Factory function for creating likelihood modules.
+
+     Parameters
+     ----------
+     config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig]
+         The configuration object for the likelihood module.
 
+     Returns
+     -------
+     nn.Module
+         The likelihood module.
+     """
+     if config is None:
+         return None
+
+     if isinstance(config, GaussianLikelihoodConfig):
+         return GaussianLikelihood(
+             predict_logvar=config.predict_logvar,
+             logvar_lowerbound=config.logvar_lowerbound,
+         )
+     elif isinstance(config, NMLikelihoodConfig):
+         return NoiseModelLikelihood(
+             data_mean=config.data_mean,
+             data_std=config.data_std,
+             noiseModel=config.noise_model,
+         )
+     else:
+         raise ValueError(f"Invalid likelihood model type: {config.model_type}")
+
+
+ # TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
  class LikelihoodModule(nn.Module):
      """
      The base class for all likelihood modules.
      It defines the fundamental structure and methods for specialized likelihood models.
      """
 
-     def distr_params(self, x):
+     def distr_params(self, x: Any) -> None:
          return None
 
-     def set_params_to_same_device_as(self, correct_device_tensor):
+     def set_params_to_same_device_as(self, correct_device_tensor: Any) -> None:
          pass
 
      @staticmethod
-     def logvar(params):
+     def logvar(params: Any) -> None:
          return None
 
      @staticmethod
-     def mean(params):
+     def mean(params: Any) -> None:
          return None
 
      @staticmethod
-     def mode(params):
+     def mode(params: Any) -> None:
          return None
 
      @staticmethod
-     def sample(params):
+     def sample(params: Any) -> None:
          return None
 
-     def log_likelihood(self, x, params):
+     def log_likelihood(self, x: Any, params: Any) -> None:
          return None
 
-     def forward(
-         self, input_: torch.Tensor, x: torch.Tensor
-     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+     def get_mean_lv(
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...
 
+     def forward(
+         self, input_: torch.Tensor, x: Union[torch.Tensor, None]
+     ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+         """
+         Parameters:
+         -----------
+         input_: torch.Tensor
+             The output of the top-down pass (e.g., reconstructed image in HDN,
+             or the unmixed images in 'Split' models).
+         x: Union[torch.Tensor, None]
+             The target tensor. If None, the log-likelihood is not computed.
+         """
          distr_params = self.distr_params(input_)
          mean = self.mean(distr_params)
          mode = self.mode(distr_params)
@@ -68,8 +130,7 @@ class LikelihoodModule(nn.Module):
 
 
  class GaussianLikelihood(LikelihoodModule):
-     r"""
-     A specialize `LikelihoodModule` for Gaussian likelihood.
+     r"""A specialized `LikelihoodModule` for Gaussian likelihood.
 
      Specifically, in the LVAE model, the likelihood is defined as:
          p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
@@ -77,50 +138,32 @@ class GaussianLikelihood(LikelihoodModule):
 
      def __init__(
          self,
-         ch_in: int,
-         color_channels: int,
-         predict_logvar: Literal[None, "pixelwise", "global", "channelwise"] = None,
-         logvar_lowerbound: float = None,
-         conv2d_bias: bool = True,
+         predict_logvar: Union[Literal["pixelwise"], None] = None,
+         logvar_lowerbound: Union[float, None] = None,
      ):
-         """
-         Constructor.
+         """Constructor.
 
          Parameters
          ----------
-         predict_logvar: Literal[None, 'global', 'pixelwise', 'channelwise'], optional
-             If not `None`, it expresses how to compute the log-variance.
-             Namely:
-             - if `pixelwise`, log-variance is computed for each pixel.
-             - if `global`, log-variance is computed as the mean of all pixel-wise entries.
-             - if `channelwise`, log-variance is computed as the average over the channels.
-             Default is `None`.
+         predict_logvar: Union[Literal["pixelwise"], None], optional
+             If `pixelwise`, log-variance is computed for each pixel, else log-variance
+             is not computed. Default is `None`.
          logvar_lowerbound: float, optional
              The lowerbound value for log-variance. Default is `None`.
-         conv2d_bias: bool, optional
-             Whether to use bias term in convolutions. Default is `True`.
          """
          super().__init__()
 
-         # If True, then we also predict pixelwise logvar.
          self.predict_logvar = predict_logvar
          self.logvar_lowerbound = logvar_lowerbound
-         self.conv2d_bias = conv2d_bias
-         assert self.predict_logvar in [None, "global", "pixelwise", "channelwise"]
-
-         # logvar_ch_needed = self.predict_logvar is not None
-         # self.parameter_net = nn.Conv2d(ch_in,
-         #                                color_channels * (1 + logvar_ch_needed),
-         #                                kernel_size=3,
-         #                                padding=1,
-         #                                bias=self.conv2d_bias)
-         self.parameter_net = nn.Identity()
+         assert self.predict_logvar in [None, "pixelwise"]
 
          print(
              f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
          )
 
-     def get_mean_lv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+     def get_mean_lv(
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """
          Given the output of the top-down pass, compute the mean and log-variance of the
          Gaussian distribution defining the likelihood.
@@ -128,50 +171,42 @@ class GaussianLikelihood(LikelihoodModule):
          Parameters
          ----------
          x: torch.Tensor
-             The input tensor to the likelihood module, i.e., the output of the top-down pass.
+             The input tensor to the likelihood module, i.e., the output of the top-down
+             pass.
+
+         Returns
+         -------
+         tuple of (torch.tensor, optional torch.tensor)
+             The first element of the tuple is the mean, the second element is the
+             log-variance. If the attribute `predict_logvar` is `None` then the second
+             element will be `None`.
          """
-         # Feed the output of the top-down pass to a parameter network
-         # This network can be either a Conv2d or Identity module
-         x = self.parameter_net(x)
 
-         if self.predict_logvar is not None:
-             # Get pixel-wise mean and logvar
-             mean, lv = x.chunk(2, dim=1)
-
-             # Optionally, compute the global or channel-wise logvar
-             if self.predict_logvar in ["channelwise", "global"]:
-                 if self.predict_logvar == "channelwise":
-                     # logvar should be of the following shape (batch, num_channels, ). Other dims would be singletons.
-                     N = np.prod(lv.shape[:2])
-                     new_shape = (*mean.shape[:2], *([1] * len(mean.shape[2:])))
-                 elif self.predict_logvar == "global":
-                     # logvar should be of the following shape (batch, ). Other dims would be singletons.
-                     N = lv.shape[0]
-                     new_shape = (*mean.shape[:1], *([1] * len(mean.shape[1:])))
-                 else:
-                     raise ValueError(
-                         f"Invalid value for self.predict_logvar:{self.predict_logvar}"
-                     )
-
-                 lv = torch.mean(lv.reshape(N, -1), dim=1)
-                 lv = lv.reshape(new_shape)
-
-             # Optionally, clip log-var to a lower bound
-             if self.logvar_lowerbound is not None:
-                 lv = torch.clip(lv, min=self.logvar_lowerbound)
-         else:
-             mean = x
-             lv = None
+         # if LadderVAE.predict_logvar is None, dim 1 of `x` has no. of target channels
+         if self.predict_logvar is None:
+             return x, None
+
+         # Get pixel-wise mean and logvar
+         # if LadderVAE.predict_logvar is not None,
+         # dim 1 has double no. of target channels
+         mean, lv = x.chunk(2, dim=1)
+
+         # Optionally, clip log-var to a lower bound
+         if self.logvar_lowerbound is not None:
+             lv = torch.clip(lv, min=self.logvar_lowerbound)
+
          return mean, lv
 
-     def distr_params(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+     def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
          """
          Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.
 
          Parameters
          ----------
          x: torch.Tensor
-             The input tensor to the likelihood module, i.e., the output of the top-down pass.
+             The input tensor to the likelihood module, i.e., the output
+             the LVAE 'output_layer'. Shape is: (B, 2 * C, [Z], Y, X) in case
+             `predict_logvar` is not None, or (B, C, [Z], Y, X) otherwise.
          """
          mean, lv = self.get_mean_lv(x)
          params = {
@@ -181,24 +216,41 @@ class GaussianLikelihood(LikelihoodModule):
          return params
 
      @staticmethod
-     def mean(params):
+     def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
          return params["mean"]
 
      @staticmethod
-     def mode(params):
+     def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
          return params["mean"]
 
      @staticmethod
-     def sample(params):
+     def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
          # p = Normal(params['mean'], (params['logvar'] / 2).exp())
          # return p.rsample()
          return params["mean"]
 
      @staticmethod
-     def logvar(params):
+     def logvar(params: dict[str, torch.Tensor]) -> torch.Tensor:
          return params["logvar"]
 
-     def log_likelihood(self, x, params):
+     def log_likelihood(
+         self, x: torch.Tensor, params: dict[str, Union[torch.Tensor, None]]
+     ):
+         """Compute Gaussian log-likelihood
+
+         Parameters
+         ----------
+         x: torch.Tensor
+             The target tensor. Shape is (B, C, [Z], Y, X).
+         params: dict[str, Union[torch.Tensor, None]]
+             The tensors obtained by chunking the output of the top-down pass,
+             here used as parameters of the Gaussian distribution.
+
+         Returns
+         -------
+         torch.Tensor
+             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
+         """
          if self.predict_logvar is not None:
              logprob = log_normal(x, params["mean"], params["logvar"])
          else:
@@ -236,39 +288,46 @@ class NoiseModelLikelihood(LikelihoodModule):
 
      def __init__(
          self,
-         ch_in: int,
-         color_channels: int,
-         data_mean: Union[Dict[str, torch.Tensor], torch.Tensor],
-         data_std: Union[Dict[str, torch.Tensor], torch.Tensor],
-         noiseModel: nn.Module,
+         data_mean: Union[np.ndarray, torch.Tensor],
+         data_std: Union[np.ndarray, torch.Tensor],
+         noiseModel: NoiseModel,
      ):
+         """Constructor.
+
+         Parameters
+         ----------
+         data_mean: Union[np.ndarray, torch.Tensor]
+             The mean of the data, used to unnormalize data for noise model evaluation.
+         data_std: Union[np.ndarray, torch.Tensor]
+             The standard deviation of the data, used to unnormalize data for noise
+             model evaluation.
+         noiseModel: NoiseModel
+             The noise model instance used to compute the likelihood.
+         """
          super().__init__()
-         self.parameter_net = (
-             nn.Identity()
-         )  # nn.Conv2d(ch_in, color_channels, kernel_size=3, padding=1)
-         self.data_mean = data_mean
-         self.data_std = data_std
+         self.data_mean = torch.Tensor(data_mean)
+         self.data_std = torch.Tensor(data_std)
          self.noiseModel = noiseModel
 
-     def set_params_to_same_device_as(self, correct_device_tensor):
-         if isinstance(self.data_mean, torch.Tensor):
-             if self.data_mean.device != correct_device_tensor.device:
-                 self.data_mean = self.data_mean.to(correct_device_tensor.device)
-                 self.data_std = self.data_std.to(correct_device_tensor.device)
-         elif isinstance(self.data_mean, dict):
-             for key in self.data_mean.keys():
-                 self.data_mean[key] = self.data_mean[key].to(
-                     correct_device_tensor.device
-                 )
-                 self.data_std[key] = self.data_std[key].to(correct_device_tensor.device)
-
-     def get_mean_lv(self, x):
-         return self.parameter_net(x), None
-
-     def distr_params(self, x):
-         mean, lv = self.get_mean_lv(x)
-         # mean, lv = x.chunk(2, dim=1)
+     def _set_params_to_same_device_as(
+         self, correct_device_tensor: torch.Tensor
+     ) -> None:
+         """Set the parameters to the same device as the input tensor.
+
+         Parameters
+         ----------
+         correct_device_tensor: torch.Tensor
+             The tensor whose device is used to set the parameters.
+         """
+         if self.data_mean.device != correct_device_tensor.device:
+             self.data_mean = self.data_mean.to(correct_device_tensor.device)
+             self.data_std = self.data_std.to(correct_device_tensor.device)
+
+     def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
+         return x, None
 
+     def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+         mean, lv = self.get_mean_lv(x)
          params = {
              "mean": mean,
              "logvar": lv,
@@ -276,37 +335,39 @@ class NoiseModelLikelihood(LikelihoodModule):
          return params
 
      @staticmethod
-     def mean(params):
+     def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
          return params["mean"]
 
      @staticmethod
-     def mode(params):
+     def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
          return params["mean"]
 
      @staticmethod
-     def sample(params):
-         # p = Normal(params['mean'], (params['logvar'] / 2).exp())
-         # return p.rsample()
+     def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
          return params["mean"]
 
-     def log_likelihood(self, x: torch.Tensor, params: Dict[str, torch.Tensor]):
-         """
-         Compute the log-likelihood given the parameters `params` obtained from the reconstruction tensor and the target tensor `x`.
+     def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
+         """Compute the log-likelihood given the parameters `params` obtained
+         from the reconstruction tensor and the target tensor `x`.
+
+         Parameters
+         ----------
+         x: torch.Tensor
+             The target tensor. Shape is (B, C, [Z], Y, X).
+         params: dict[str, Union[torch.Tensor, None]]
+             The tensors obtained from output of the top-down pass.
+             Here, "mean" correspond to the whole output, while logvar is `None`.
+
+         Returns
+         -------
+         torch.Tensor
+             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
          """
-         predicted_s_denormalized = (
-             params["mean"] * self.data_std["target"] + self.data_mean["target"]
-         )
-         x_denormalized = x * self.data_std["target"] + self.data_mean["target"]
-         # predicted_s_cloned = predicted_s_denormalized
-         # predicted_s_reduced = predicted_s_cloned.permute(1, 0, 2, 3)
-
-         # x_cloned = x_denormalized
-         # x_cloned = x_cloned.permute(1, 0, 2, 3)
-         # x_reduced = x_cloned[0, ...]
-         # import pdb;pdb.set_trace()
+         self._set_params_to_same_device_as(x)
+         predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
+         x_denormalized = x * self.data_std + self.data_mean
          likelihoods = self.noiseModel.likelihood(
              x_denormalized, predicted_s_denormalized
          )
-         # likelihoods = self.noiseModel.likelihood(x, params['mean'])
          logprob = torch.log(likelihoods)
          return logprob
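
For orientation, a minimal usage sketch of the config-driven likelihood_factory added in this release, based only on the code visible in the diff above. The keyword construction of GaussianLikelihoodConfig (the diff only shows the fields predict_logvar and logvar_lowerbound being read) and the second element of forward's return tuple are assumptions, not confirmed by this diff.

import torch

from careamics.config.likelihood_model import GaussianLikelihoodConfig
from careamics.models.lvae.likelihoods import likelihood_factory

# Gaussian likelihood with a pixel-wise log-variance, clipped from below.
# "pixelwise" is the only predict_logvar mode kept in 0.0.4 besides None.
config = GaussianLikelihoodConfig(
    predict_logvar="pixelwise",  # assumed field names, mirroring the factory
    logvar_lowerbound=-5.0,
)
likelihood = likelihood_factory(config)

# With predict_logvar="pixelwise", dim 1 of the decoder output carries twice
# the number of target channels: mean and log-variance, split via chunk(2, dim=1).
decoder_output = torch.randn(2, 2 * 3, 64, 64)  # (B, 2 * C, Y, X)
target = torch.randn(2, 3, 64, 64)              # (B, C, Y, X)

# forward is annotated as tuple[torch.Tensor, dict[str, torch.Tensor]]:
# a per-pixel log-likelihood of shape (B, C, Y, X), plus distribution data.
log_likelihood, distribution_data = likelihood(decoder_output, target)
print(log_likelihood.shape)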