careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (64)
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
--- a/careamics/models/lvae/likelihoods.py
+++ b/careamics/models/lvae/likelihoods.py
@@ -1,50 +1,111 @@
 """
-Script containing modules for definining different likelihood functions (as nn.Module).
+Script containing modules for defining different likelihood functions (as nn.Module).
 """
 
+from __future__ import annotations
+
 import math
-from typing import Dict, Literal, Tuple, Union
+from typing import Literal, Union, TYPE_CHECKING, Any, Optional
 
-import numpy as np
 import torch
 from torch import nn
 
+from careamics.config.likelihood_model import (
+    GaussianLikelihoodConfig,
+    NMLikelihoodConfig,
+)
+
+if TYPE_CHECKING:
+    from careamics.models.lvae.noise_models import (
+        GaussianMixtureNoiseModel,
+        MultiChannelNoiseModel,
+    )
+
+    NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+def likelihood_factory(
+    config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
+):
+    """
+    Factory function for creating likelihood modules.
+
+    Parameters
+    ----------
+    config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig]
+        The configuration object for the likelihood module.
+
+    Returns
+    -------
+    nn.Module
+        The likelihood module.
+    """
+    if config is None:
+        return None
 
+    if isinstance(config, GaussianLikelihoodConfig):
+        return GaussianLikelihood(
+            predict_logvar=config.predict_logvar,
+            logvar_lowerbound=config.logvar_lowerbound,
+        )
+    elif isinstance(config, NMLikelihoodConfig):
+        return NoiseModelLikelihood(
+            data_mean=config.data_mean,
+            data_std=config.data_std,
+            noiseModel=config.noise_model,
+        )
+    else:
+        raise ValueError(f"Invalid likelihood model type: {config.model_type}")
+
+
+# TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
 class LikelihoodModule(nn.Module):
     """
     The base class for all likelihood modules.
     It defines the fundamental structure and methods for specialized likelihood models.
     """
 
-    def distr_params(self, x):
+    def distr_params(self, x: Any) -> None:
         return None
 
-    def set_params_to_same_device_as(self, correct_device_tensor):
+    def set_params_to_same_device_as(self, correct_device_tensor: Any) -> None:
         pass
 
     @staticmethod
-    def logvar(params):
+    def logvar(params: Any) -> None:
         return None
 
     @staticmethod
-    def mean(params):
+    def mean(params: Any) -> None:
         return None
 
     @staticmethod
-    def mode(params):
+    def mode(params: Any) -> None:
         return None
 
     @staticmethod
-    def sample(params):
+    def sample(params: Any) -> None:
         return None
 
-    def log_likelihood(self, x, params):
+    def log_likelihood(self, x: Any, params: Any) -> None:
         return None
 
-    def forward(
-        self, input_: torch.Tensor, x: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+    def get_mean_lv(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...
 
+    def forward(
+        self, input_: torch.Tensor, x: Union[torch.Tensor, None]
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        """
+        Parameters:
+        -----------
+        input_: torch.Tensor
+            The output of the top-down pass (e.g., reconstructed image in HDN,
+            or the unmixed images in 'Split' models).
+        x: Union[torch.Tensor, None]
+            The target tensor. If None, the log-likelihood is not computed.
+        """
         distr_params = self.distr_params(input_)
         mean = self.mean(distr_params)
         mode = self.mode(distr_params)
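
Note on the hunk above: the new `likelihood_factory` dispatches on the config type and returns `None` when no config is given. A minimal usage sketch; the field values below are illustrative and constructing the config this way is an assumption (only the field names `predict_logvar` and `logvar_lowerbound` appear in the diff):

    from careamics.config.likelihood_model import GaussianLikelihoodConfig
    from careamics.models.lvae.likelihoods import likelihood_factory

    # Build a Gaussian likelihood module from its config.
    config = GaussianLikelihoodConfig(predict_logvar="pixelwise", logvar_lowerbound=-5.0)
    likelihood = likelihood_factory(config)

    # No config means no likelihood module.
    assert likelihood_factory(None) is None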
@@ -68,8 +129,7 @@ class LikelihoodModule(nn.Module):
 
 
 class GaussianLikelihood(LikelihoodModule):
-    r"""
-    A specialize `LikelihoodModule` for Gaussian likelihood.
+    r"""A specialized `LikelihoodModule` for Gaussian likelihood.
 
     Specifically, in the LVAE model, the likelihood is defined as:
     p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
@@ -77,50 +137,32 @@ class GaussianLikelihood(LikelihoodModule):
 
     def __init__(
         self,
-        ch_in: int,
-        color_channels: int,
-        predict_logvar: Literal[None, "pixelwise", "global", "channelwise"] = None,
-        logvar_lowerbound: float = None,
-        conv2d_bias: bool = True,
+        predict_logvar: Union[Literal["pixelwise"], None] = None,
+        logvar_lowerbound: Union[float, None] = None,
     ):
-        """
-        Constructor.
+        """Constructor.
 
         Parameters
         ----------
-        predict_logvar: Literal[None, 'global', 'pixelwise', 'channelwise'], optional
-            If not `None`, it expresses how to compute the log-variance.
-            Namely:
-            - if `pixelwise`, log-variance is computed for each pixel.
-            - if `global`, log-variance is computed as the mean of all pixel-wise entries.
-            - if `channelwise`, log-variance is computed as the average over the channels.
-            Default is `None`.
+        predict_logvar: Union[Literal["pixelwise"], None], optional
+            If `pixelwise`, log-variance is computed for each pixel, else log-variance
+            is not computed. Default is `None`.
         logvar_lowerbound: float, optional
             The lowerbound value for log-variance. Default is `None`.
-        conv2d_bias: bool, optional
-            Whether to use bias term in convolutions. Default is `True`.
         """
         super().__init__()
 
-        # If True, then we also predict pixelwise logvar.
        self.predict_logvar = predict_logvar
        self.logvar_lowerbound = logvar_lowerbound
-        self.conv2d_bias = conv2d_bias
-        assert self.predict_logvar in [None, "global", "pixelwise", "channelwise"]
-
-        # logvar_ch_needed = self.predict_logvar is not None
-        # self.parameter_net = nn.Conv2d(ch_in,
-        #                                color_channels * (1 + logvar_ch_needed),
-        #                                kernel_size=3,
-        #                                padding=1,
-        #                                bias=self.conv2d_bias)
-        self.parameter_net = nn.Identity()
+        assert self.predict_logvar in [None, "pixelwise"]
 
         print(
             f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
         )
 
-    def get_mean_lv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def get_mean_lv(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         Given the output of the top-down pass, compute the mean and log-variance of the
         Gaussian distribution defining the likelihood.
@@ -128,50 +170,42 @@ class GaussianLikelihood(LikelihoodModule):
         Parameters
         ----------
         x: torch.Tensor
-            The input tensor to the likelihood module, i.e., the output of the top-down pass.
+            The input tensor to the likelihood module, i.e., the output of the top-down
+            pass.
+
+        Returns
+        -------
+        tuple of (torch.tensor, optional torch.tensor)
+            The first element of the tuple is the mean, the second element is the
+            log-variance. If the attribute `predict_logvar` is `None` then the second
+            element will be `None`.
         """
-        # Feed the output of the top-down pass to a parameter network
-        # This network can be either a Conv2d or Identity module
-        x = self.parameter_net(x)
 
-        if self.predict_logvar is not None:
-            # Get pixel-wise mean and logvar
-            mean, lv = x.chunk(2, dim=1)
-
-            # Optionally, compute the global or channel-wise logvar
-            if self.predict_logvar in ["channelwise", "global"]:
-                if self.predict_logvar == "channelwise":
-                    # logvar should be of the following shape (batch, num_channels, ). Other dims would be singletons.
-                    N = np.prod(lv.shape[:2])
-                    new_shape = (*mean.shape[:2], *([1] * len(mean.shape[2:])))
-                elif self.predict_logvar == "global":
-                    # logvar should be of the following shape (batch, ). Other dims would be singletons.
-                    N = lv.shape[0]
-                    new_shape = (*mean.shape[:1], *([1] * len(mean.shape[1:])))
-                else:
-                    raise ValueError(
-                        f"Invalid value for self.predict_logvar:{self.predict_logvar}"
-                    )
-
-                lv = torch.mean(lv.reshape(N, -1), dim=1)
-                lv = lv.reshape(new_shape)
-
-            # Optionally, clip log-var to a lower bound
-            if self.logvar_lowerbound is not None:
-                lv = torch.clip(lv, min=self.logvar_lowerbound)
-        else:
-            mean = x
-            lv = None
+        # if LadderVAE.predict_logvar is None, dim 1 of `x` has no. of target channels
+        if self.predict_logvar is None:
+            return x, None
+
+        # Get pixel-wise mean and logvar
+        # if LadderVAE.predict_logvar is not None,
+        # dim 1 has double no. of target channels
+        mean, lv = x.chunk(2, dim=1)
+
+        # Optionally, clip log-var to a lower bound
+        if self.logvar_lowerbound is not None:
+            lv = torch.clip(lv, min=self.logvar_lowerbound)
+
         return mean, lv
 
-    def distr_params(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         """
         Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.
 
         Parameters
         ----------
         x: torch.Tensor
-            The input tensor to the likelihood module, i.e., the output of the top-down pass.
+            The input tensor to the likelihood module, i.e., the output
+            the LVAE 'output_layer'. Shape is: (B, 2 * C, [Z], Y, X) in case
+            `predict_logvar` is not None, or (B, C, [Z], Y, X) otherwise.
         """
         mean, lv = self.get_mean_lv(x)
         params = {
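
Note on the hunk above: `get_mean_lv` no longer supports `global`/`channelwise` log-variance; the LVAE output layer now emits either C channels (mean only) or 2*C channels (mean plus per-pixel log-variance). A minimal sketch of the split, with illustrative shapes:

    import torch

    # predict_logvar="pixelwise": dim 1 carries 2*C channels, here C=3.
    x = torch.randn(4, 2 * 3, 64, 64)   # (B, 2*C, Y, X) output of the LVAE
    mean, lv = x.chunk(2, dim=1)        # two (4, 3, 64, 64) tensors
    lv = torch.clip(lv, min=-5.0)       # optional logvar_lowerbound clipping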
@@ -181,24 +215,41 @@ class GaussianLikelihood(LikelihoodModule):
         return params
 
     @staticmethod
-    def mean(params):
+    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def mode(params):
+    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def sample(params):
+    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
         # p = Normal(params['mean'], (params['logvar'] / 2).exp())
         # return p.rsample()
         return params["mean"]
 
     @staticmethod
-    def logvar(params):
+    def logvar(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["logvar"]
 
-    def log_likelihood(self, x, params):
+    def log_likelihood(
+        self, x: torch.Tensor, params: dict[str, Union[torch.Tensor, None]]
+    ):
+        """Compute Gaussian log-likelihood
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The target tensor. Shape is (B, C, [Z], Y, X).
+        params: dict[str, Union[torch.Tensor, None]]
+            The tensors obtained by chunking the output of the top-down pass,
+            here used as parameters of the Gaussian distribution.
+
+        Returns
+        -------
+        torch.Tensor
+            The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
+        """
         if self.predict_logvar is not None:
             logprob = log_normal(x, params["mean"], params["logvar"])
         else:
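
Note on the hunk above: `log_normal` is defined elsewhere in this file and is not part of this diff. For reference, the element-wise Gaussian log-density it is assumed to evaluate is sketched here:

    import math
    import torch

    def log_normal_sketch(x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # log N(x | mean, exp(logvar)), computed element-wise.
        return -0.5 * (logvar + math.log(2 * math.pi) + (x - mean) ** 2 / logvar.exp())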
@@ -236,39 +287,39 @@ class NoiseModelLikelihood(LikelihoodModule):
 
     def __init__(
         self,
-        ch_in: int,
-        color_channels: int,
-        data_mean: Union[Dict[str, torch.Tensor], torch.Tensor],
-        data_std: Union[Dict[str, torch.Tensor], torch.Tensor],
-        noiseModel: nn.Module,
+        data_mean: torch.Tensor,
+        data_std: torch.Tensor,
+        noiseModel: NoiseModel,  # TODO: check the type -> couldn't manage due to circular imports...
     ):
+        """Constructor.
+
+        Parameters
+        ----------
+        data_mean: torch.Tensor
+            The mean of the data, used to unnormalize data for noise model evaluation.
+        data_std: torch.Tensor
+            The standard deviation of the data, used to unnormalize data for noise
+            model evaluation.
+        noiseModel: NoiseModel
+            The noise model instance used to compute the likelihood.
+        """
         super().__init__()
-        self.parameter_net = (
-            nn.Identity()
-        )  # nn.Conv2d(ch_in, color_channels, kernel_size=3, padding=1)
         self.data_mean = data_mean
         self.data_std = data_std
         self.noiseModel = noiseModel
 
-    def set_params_to_same_device_as(self, correct_device_tensor):
-        if isinstance(self.data_mean, torch.Tensor):
-            if self.data_mean.device != correct_device_tensor.device:
-                self.data_mean = self.data_mean.to(correct_device_tensor.device)
-                self.data_std = self.data_std.to(correct_device_tensor.device)
-        elif isinstance(self.data_mean, dict):
-            for key in self.data_mean.keys():
-                self.data_mean[key] = self.data_mean[key].to(
-                    correct_device_tensor.device
-                )
-                self.data_std[key] = self.data_std[key].to(correct_device_tensor.device)
-
-    def get_mean_lv(self, x):
-        return self.parameter_net(x), None
-
-    def distr_params(self, x):
-        mean, lv = self.get_mean_lv(x)
-        # mean, lv = x.chunk(2, dim=1)
+    def set_params_to_same_device_as(
+        self, correct_device_tensor: torch.Tensor
+    ) -> None:  # TODO: needed?
+        if self.data_mean.device != correct_device_tensor.device:
+            self.data_mean = self.data_mean.to(correct_device_tensor.device)
+            self.data_std = self.data_std.to(correct_device_tensor.device)
+
+    def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
+        return x, None
 
+    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        mean, lv = self.get_mean_lv(x)
         params = {
             "mean": mean,
             "logvar": lv,
@@ -276,37 +327,38 @@ class NoiseModelLikelihood(LikelihoodModule):
         return params
 
     @staticmethod
-    def mean(params):
+    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def mode(params):
+    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
     @staticmethod
-    def sample(params):
-        # p = Normal(params['mean'], (params['logvar'] / 2).exp())
-        # return p.rsample()
+    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
         return params["mean"]
 
-    def log_likelihood(self, x: torch.Tensor, params: Dict[str, torch.Tensor]):
-        """
-        Compute the log-likelihood given the parameters `params` obtained from the reconstruction tensor and the target tensor `x`.
+    def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
+        """Compute the log-likelihood given the parameters `params` obtained
+        from the reconstruction tensor and the target tensor `x`.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            The target tensor. Shape is (B, C, [Z], Y, X).
+        params: dict[str, Union[torch.Tensor, None]]
+            The tensors obtained from output of the top-down pass.
+            Here, "mean" correspond to the whole output, while logvar is `None`.
+
+        Returns
+        -------
+        torch.Tensor
+            The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
         """
-        predicted_s_denormalized = (
-            params["mean"] * self.data_std["target"] + self.data_mean["target"]
-        )
-        x_denormalized = x * self.data_std["target"] + self.data_mean["target"]
-        # predicted_s_cloned = predicted_s_denormalized
-        # predicted_s_reduced = predicted_s_cloned.permute(1, 0, 2, 3)
-
-        # x_cloned = x_denormalized
-        # x_cloned = x_cloned.permute(1, 0, 2, 3)
-        # x_reduced = x_cloned[0, ...]
-        # import pdb;pdb.set_trace()
+        predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
+        x_denormalized = x * self.data_std + self.data_mean
         likelihoods = self.noiseModel.likelihood(
             x_denormalized, predicted_s_denormalized
         )
-        # likelihoods = self.noiseModel.likelihood(x, params['mean'])
         logprob = torch.log(likelihoods)
         return logprob
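
Note on the hunk above: the denormalization is now a plain affine rescale with the tensor statistics (the old code indexed `self.data_std["target"]`). A sketch of the computation with illustrative values; the final call into the noise model mirrors the code above:

    import torch

    data_mean, data_std = torch.tensor(128.0), torch.tensor(32.0)
    pred = torch.randn(4, 3, 64, 64)    # normalized prediction (params["mean"])
    target = torch.randn(4, 3, 64, 64)  # normalized target x

    # Undo normalization before evaluating the noise model.
    pred_denorm = pred * data_std + data_mean
    target_denorm = target * data_std + data_mean
    # logprob = torch.log(noise_model.likelihood(target_denorm, pred_denorm))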