careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.
Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/lvae/likelihoods.py
@@ -0,0 +1,364 @@
+ """
+ Script containing modules for defining different likelihood functions (as nn.Module).
+ """
+
+ from __future__ import annotations
+
+ import math
+ from typing import Literal, Union, TYPE_CHECKING, Any, Optional
+
+ import torch
+ from torch import nn
+
+ from careamics.config.likelihood_model import (
+     GaussianLikelihoodConfig,
+     NMLikelihoodConfig,
+ )
+
+ if TYPE_CHECKING:
+     from careamics.models.lvae.noise_models import (
+         GaussianMixtureNoiseModel,
+         MultiChannelNoiseModel,
+     )
+
+     NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ def likelihood_factory(
+     config: Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
+ ):
+     """
+     Factory function for creating likelihood modules.
+
+     Parameters
+     ----------
+     config : Union[GaussianLikelihoodConfig, NMLikelihoodConfig, None]
+         The configuration object for the likelihood module, or `None`.
+
+     Returns
+     -------
+     Optional[nn.Module]
+         The likelihood module, or `None` if `config` is `None`.
+     """
+     if config is None:
+         return None
+
+     if isinstance(config, GaussianLikelihoodConfig):
+         return GaussianLikelihood(
+             predict_logvar=config.predict_logvar,
+             logvar_lowerbound=config.logvar_lowerbound,
+         )
+     elif isinstance(config, NMLikelihoodConfig):
+         return NoiseModelLikelihood(
+             data_mean=config.data_mean,
+             data_std=config.data_std,
+             noiseModel=config.noise_model,
+         )
+     else:
+         raise ValueError(f"Invalid likelihood model type: {config.model_type}")
+
+
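As a usage sketch (editorial, not part of the diff): the factory dispatches purely on the config type, so callers never reference the likelihood classes directly. The field values below are illustrative assumptions; the field names match the factory calls above.

import torch

from careamics.config.likelihood_model import GaussianLikelihoodConfig
from careamics.models.lvae.likelihoods import likelihood_factory

# illustrative config values; "pixelwise" and the lower bound are assumptions
config = GaussianLikelihoodConfig(predict_logvar="pixelwise", logvar_lowerbound=-5.0)
likelihood = likelihood_factory(config)  # -> GaussianLikelihood instance
disabled = likelihood_factory(None)      # -> None, no likelihood term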
+ # TODO: is it really worth having this class, or does it just add complexity? --> REFACTOR
+ class LikelihoodModule(nn.Module):
+     """
+     The base class for all likelihood modules.
+
+     It defines the fundamental structure and methods for specialized likelihood models.
+     """
+
+     def distr_params(self, x: Any) -> None:
+         return None
+
+     def set_params_to_same_device_as(self, correct_device_tensor: Any) -> None:
+         pass
+
+     @staticmethod
+     def logvar(params: Any) -> None:
+         return None
+
+     @staticmethod
+     def mean(params: Any) -> None:
+         return None
+
+     @staticmethod
+     def mode(params: Any) -> None:
+         return None
+
+     @staticmethod
+     def sample(params: Any) -> None:
+         return None
+
+     def log_likelihood(self, x: Any, params: Any) -> None:
+         return None
+
+     def get_mean_lv(
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ...
+
+     def forward(
+         self, input_: torch.Tensor, x: Union[torch.Tensor, None]
+     ) -> tuple[Union[torch.Tensor, None], dict[str, Any]]:
+         """
+         Parameters
+         ----------
+         input_ : torch.Tensor
+             The output of the top-down pass (e.g., the reconstructed image in HDN,
+             or the unmixed images in 'Split' models).
+         x : Union[torch.Tensor, None]
+             The target tensor. If `None`, the log-likelihood is not computed.
+         """
+         distr_params = self.distr_params(input_)
+         mean = self.mean(distr_params)
+         mode = self.mode(distr_params)
+         sample = self.sample(distr_params)
+         logvar = self.logvar(distr_params)
+
+         # the log-likelihood is only computed when a target is provided
+         if x is None:
+             ll = None
+         else:
+             ll = self.log_likelihood(x, distr_params)
+
+         dct = {
+             "mean": mean,
+             "mode": mode,
+             "sample": sample,
+             "params": distr_params,
+             "logvar": logvar,
+         }
+
+         return ll, dct
+
+
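Because every subclass inherits this `forward`, downstream code can consume a uniform `(ll, dct)` pair whatever the likelihood type. A minimal sketch of that contract; the mean-reduction to a scalar loss is an illustrative assumption, not code from this diff.

import torch

from careamics.models.lvae.likelihoods import LikelihoodModule

def reconstruction_loss(
    likelihood: LikelihoodModule,
    reconstruction: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    # forward returns (log-likelihood, dict with "mean"/"mode"/"sample"/"params"/"logvar")
    ll, out = likelihood(reconstruction, target)
    return -ll.mean()  # minimize the negative log-likelihood (illustrative reduction)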
+ class GaussianLikelihood(LikelihoodModule):
+     r"""A specialized `LikelihoodModule` for Gaussian likelihood.
+
+     Specifically, in the LVAE model, the likelihood is defined as:
+         p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
+     """
+
+     def __init__(
+         self,
+         predict_logvar: Union[Literal["pixelwise"], None] = None,
+         logvar_lowerbound: Union[float, None] = None,
+     ):
+         """Constructor.
+
+         Parameters
+         ----------
+         predict_logvar : Union[Literal["pixelwise"], None], optional
+             If `"pixelwise"`, the log-variance is computed for each pixel;
+             otherwise the log-variance is not computed. Default is `None`.
+         logvar_lowerbound : float, optional
+             The lower bound for the log-variance. Default is `None`.
+         """
+         super().__init__()
+
+         self.predict_logvar = predict_logvar
+         self.logvar_lowerbound = logvar_lowerbound
+         assert self.predict_logvar in [None, "pixelwise"]
+
+         print(
+             f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
+         )
+
+     def get_mean_lv(
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """
+         Given the output of the top-down pass, compute the mean and log-variance of
+         the Gaussian distribution defining the likelihood.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The input tensor to the likelihood module, i.e., the output of the
+             top-down pass.
+
+         Returns
+         -------
+         tuple[torch.Tensor, Optional[torch.Tensor]]
+             The first element of the tuple is the mean, the second is the
+             log-variance. If the attribute `predict_logvar` is `None`, the second
+             element is `None`.
+         """
+         # if `LadderVAE.predict_logvar` is None, dim 1 of `x` has the number of
+         # target channels
+         if self.predict_logvar is None:
+             return x, None
+
+         # get pixel-wise mean and log-variance: if `LadderVAE.predict_logvar` is
+         # not None, dim 1 has twice the number of target channels
+         mean, lv = x.chunk(2, dim=1)
+
+         # optionally, clip the log-variance to a lower bound
+         if self.logvar_lowerbound is not None:
+             lv = torch.clip(lv, min=self.logvar_lowerbound)
+
+         return mean, lv
+
+     def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+         """
+         Get the parameters (mean, log-variance) of the Gaussian distribution
+         defined by the likelihood.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The input tensor to the likelihood module, i.e., the output of the
+             LVAE `output_layer`. Shape is (B, 2 * C, [Z], Y, X) if `predict_logvar`
+             is not `None`, or (B, C, [Z], Y, X) otherwise.
+         """
+         mean, lv = self.get_mean_lv(x)
+         params = {
+             "mean": mean,
+             "logvar": lv,
+         }
+         return params
+
+     @staticmethod
+     def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
+         return params["mean"]
+
+     @staticmethod
+     def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
+         return params["mean"]
+
+     @staticmethod
+     def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
+         # p = Normal(params['mean'], (params['logvar'] / 2).exp())
+         # return p.rsample()
+         return params["mean"]
+
+     @staticmethod
+     def logvar(params: dict[str, torch.Tensor]) -> torch.Tensor:
+         return params["logvar"]
+
+     def log_likelihood(
+         self, x: torch.Tensor, params: dict[str, Union[torch.Tensor, None]]
+     ) -> torch.Tensor:
+         """Compute the Gaussian log-likelihood.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The target tensor. Shape is (B, C, [Z], Y, X).
+         params : dict[str, Union[torch.Tensor, None]]
+             The tensors obtained by chunking the output of the top-down pass,
+             used here as the parameters of the Gaussian distribution.
+
+         Returns
+         -------
+         torch.Tensor
+             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
+         """
+         if self.predict_logvar is not None:
+             logprob = log_normal(x, params["mean"], params["logvar"])
+         else:
+             # with a fixed unit variance, the log-likelihood reduces (up to a
+             # constant) to the negative squared error
+             logprob = -0.5 * (params["mean"] - x) ** 2
+         return logprob
+
+
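A runnable sketch of the pixel-wise branch (shapes are illustrative assumptions): with `predict_logvar="pixelwise"` the network output carries 2 * C channels, which `get_mean_lv` chunks into mean and log-variance.

import torch

from careamics.models.lvae.likelihoods import GaussianLikelihood

likelihood = GaussianLikelihood(predict_logvar="pixelwise", logvar_lowerbound=-5.0)

reconstruction = torch.randn(4, 2, 64, 64)  # B=4, 2*C channels with C=1
target = torch.randn(4, 1, 64, 64)          # B=4, C=1

ll, out = likelihood(reconstruction, target)
assert ll.shape == target.shape             # per-pixel log-likelihood
assert out["mean"].shape == out["logvar"].shape == target.shape
loss = -ll.mean()                           # illustrative scalar reduction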
+ def log_normal(
+     x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Compute the log-probability at `x` of a Gaussian distribution
+     with parameters `(mean, exp(logvar))`.
+
+     NOTE: In the case of the LVAE, the log-likelihood formula becomes:
+
+     .. math::
+
+         \\mathbb{E}_{z_1 \\sim q_\\phi}\\left[\\log p_\\theta(x|z_1)\\right] =
+         -\\frac{1}{2}\\left(\\mathbb{E}_{z_1 \\sim q_\\phi}\\left[\\log 2\\pi\\sigma_{p,0}^2(z_1)\\right]
+         + \\mathbb{E}_{z_1 \\sim q_\\phi}\\left[\\frac{(x - \\mu_{p,0}(z_1))^2}{\\sigma_{p,0}^2(z_1)}\\right]\\right)
+
+     Parameters
+     ----------
+     x : torch.Tensor
+         The ground-truth tensor. Shape is (batch, channels, dim1, dim2).
+     mean : torch.Tensor
+         The inferred mean of the distribution. Shape is (batch, channels, dim1, dim2).
+     logvar : torch.Tensor
+         The inferred log-variance of the distribution. Its shape has to be either
+         scalar or broadcastable with `x`.
+     """
+     var = torch.exp(logvar)
+     log_prob = -0.5 * (
+         ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
+     )
+     return log_prob
+
+
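As a sanity check on the formula (an editorial verification, not part of the package): `log_normal` should agree with `torch.distributions.Normal` evaluated with scale `exp(logvar / 2)`.

import math

import torch
from torch.distributions import Normal

x = torch.randn(8)
mean = torch.zeros(8)
logvar = torch.full((8,), -1.0)

# same expression as log_normal above
manual = -0.5 * (((x - mean) ** 2) / logvar.exp() + logvar + math.log(2 * math.pi))
reference = Normal(mean, (logvar / 2).exp()).log_prob(x)
assert torch.allclose(manual, reference, atol=1e-6)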
+ class NoiseModelLikelihood(LikelihoodModule):
+     """A specialized `LikelihoodModule` that evaluates the likelihood through a noise model."""
+
+     def __init__(
+         self,
+         data_mean: torch.Tensor,
+         data_std: torch.Tensor,
+         noiseModel: NoiseModel,  # TODO: check the type -> couldn't manage due to circular imports...
+     ):
+         """Constructor.
+
+         Parameters
+         ----------
+         data_mean : torch.Tensor
+             The mean of the data, used to unnormalize data for noise model evaluation.
+         data_std : torch.Tensor
+             The standard deviation of the data, used to unnormalize data for noise
+             model evaluation.
+         noiseModel : NoiseModel
+             The noise model instance used to compute the likelihood.
+         """
+         super().__init__()
+         self.data_mean = data_mean
+         self.data_std = data_std
+         self.noiseModel = noiseModel
+
+     def set_params_to_same_device_as(
+         self, correct_device_tensor: torch.Tensor
+     ) -> None:  # TODO: needed?
+         if self.data_mean.device != correct_device_tensor.device:
+             self.data_mean = self.data_mean.to(correct_device_tensor.device)
+             self.data_std = self.data_std.to(correct_device_tensor.device)
+
+     def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
+         return x, None
+
+     def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+         mean, lv = self.get_mean_lv(x)
+         params = {
+             "mean": mean,
+             "logvar": lv,
+         }
+         return params
+
+     @staticmethod
+     def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
+         return params["mean"]
+
+     @staticmethod
+     def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
+         return params["mean"]
+
+     @staticmethod
+     def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
+         return params["mean"]
+
+     def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]) -> torch.Tensor:
+         """Compute the log-likelihood of the target `x` given the parameters
+         `params` obtained from the reconstruction tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The target tensor. Shape is (B, C, [Z], Y, X).
+         params : dict[str, Union[torch.Tensor, None]]
+             The tensors obtained from the output of the top-down pass.
+             Here, "mean" corresponds to the whole output, while "logvar" is `None`.
+
+         Returns
+         -------
+         torch.Tensor
+             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
+         """
+         predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
+         x_denormalized = x * self.data_std + self.data_mean
+         likelihoods = self.noiseModel.likelihood(
+             x_denormalized, predicted_s_denormalized
+         )
+         logprob = torch.log(likelihoods)
+         return logprob
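
`NoiseModelLikelihood` delegates the density evaluation entirely to the injected noise model, so any object exposing a `likelihood(observation, signal)` method that returns positive densities will work. A sketch with a hypothetical unit-variance Gaussian stub standing in for careamics' GaussianMixtureNoiseModel; the normalization statistics are illustrative.

import math

import torch

from careamics.models.lvae.likelihoods import NoiseModelLikelihood

class UnitGaussianStub:
    """Hypothetical stand-in for a trained noise model."""

    def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
        return torch.exp(-0.5 * (obs - signal) ** 2) / math.sqrt(2 * math.pi)

nm_likelihood = NoiseModelLikelihood(
    data_mean=torch.tensor(100.0),  # illustrative normalization statistics
    data_std=torch.tensor(10.0),
    noiseModel=UnitGaussianStub(),
)

reconstruction = torch.randn(4, 1, 32, 32)     # normalized network output
target = torch.randn(4, 1, 32, 32)             # normalized target
ll, _ = nm_likelihood(reconstruction, target)  # inputs are denormalized internally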