careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics may be problematic.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
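
The hunk below shows the new file `careamics/models/lvae/likelihoods.py` (entry 63 above, +312 lines), which defines the likelihood modules used by the new LVAE code.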
@@ -0,0 +1,312 @@
+ """
+ Script containing modules for defining different likelihood functions (as nn.Module).
+ """
+
+ import math
+ from typing import Dict, Literal, Tuple, Union
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+
+ class LikelihoodModule(nn.Module):
+     """
+     The base class for all likelihood modules.
+     It defines the fundamental structure and methods for specialized likelihood models.
+     """
+
+     def distr_params(self, x):
+         return None
+
+     def set_params_to_same_device_as(self, correct_device_tensor):
+         pass
+
+     @staticmethod
+     def logvar(params):
+         return None
+
+     @staticmethod
+     def mean(params):
+         return None
+
+     @staticmethod
+     def mode(params):
+         return None
+
+     @staticmethod
+     def sample(params):
+         return None
+
+     def log_likelihood(self, x, params):
+         return None
+
+     def forward(
+         self, input_: torch.Tensor, x: torch.Tensor
+     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         distr_params = self.distr_params(input_)
+         mean = self.mean(distr_params)
+         mode = self.mode(distr_params)
+         sample = self.sample(distr_params)
+         logvar = self.logvar(distr_params)
+
+         if x is None:
+             ll = None
+         else:
+             ll = self.log_likelihood(x, distr_params)
+
+         dct = {
+             "mean": mean,
+             "mode": mode,
+             "sample": sample,
+             "params": distr_params,
+             "logvar": logvar,
+         }
+
+         return ll, dct
+
+
+ class GaussianLikelihood(LikelihoodModule):
+     r"""
+     A specialized `LikelihoodModule` for Gaussian likelihood.
+
+     Specifically, in the LVAE model, the likelihood is defined as:
+         p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)
+     """
+
+     def __init__(
+         self,
+         ch_in: int,
+         color_channels: int,
+         predict_logvar: Literal[None, "pixelwise", "global", "channelwise"] = None,
+         logvar_lowerbound: float = None,
+         conv2d_bias: bool = True,
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         ch_in: int
+             The number of input channels (unused here, since `parameter_net` is `nn.Identity`).
+         color_channels: int
+             The number of output color channels (unused here, see above).
+         predict_logvar: Literal[None, 'pixelwise', 'global', 'channelwise'], optional
+             If not `None`, it expresses how to compute the log-variance.
+             Namely:
+             - if `pixelwise`, log-variance is computed for each pixel.
+             - if `global`, log-variance is computed as the mean of all pixel-wise entries.
+             - if `channelwise`, log-variance is computed as the average over the channels.
+             Default is `None`.
+         logvar_lowerbound: float, optional
+             The lower bound for the log-variance. Default is `None`.
+         conv2d_bias: bool, optional
+             Whether to use a bias term in convolutions. Default is `True`.
+         """
+         super().__init__()
+
+         # If not None, we also predict the logvar (pixelwise, global, or channelwise).
+         self.predict_logvar = predict_logvar
+         self.logvar_lowerbound = logvar_lowerbound
+         self.conv2d_bias = conv2d_bias
+         assert self.predict_logvar in [None, "global", "pixelwise", "channelwise"]
+
+         # logvar_ch_needed = self.predict_logvar is not None
+         # self.parameter_net = nn.Conv2d(ch_in,
+         #                                color_channels * (1 + logvar_ch_needed),
+         #                                kernel_size=3,
+         #                                padding=1,
+         #                                bias=self.conv2d_bias)
+         self.parameter_net = nn.Identity()
+
+         print(
+             f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
+         )
+
+     def get_mean_lv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Given the output of the top-down pass, compute the mean and log-variance of the
+         Gaussian distribution defining the likelihood.
+
+         Parameters
+         ----------
+         x: torch.Tensor
+             The input tensor to the likelihood module, i.e., the output of the top-down pass.
+         """
+         # Feed the output of the top-down pass to a parameter network.
+         # This network can be either a Conv2d or an Identity module.
+         x = self.parameter_net(x)
+
+         if self.predict_logvar is not None:
+             # Get pixel-wise mean and logvar
+             mean, lv = x.chunk(2, dim=1)
+
+             # Optionally, compute the global or channel-wise logvar
+             if self.predict_logvar in ["channelwise", "global"]:
+                 if self.predict_logvar == "channelwise":
+                     # logvar has shape (batch, num_channels); other dims are singletons.
+                     N = np.prod(lv.shape[:2])
+                     new_shape = (*mean.shape[:2], *([1] * len(mean.shape[2:])))
+                 elif self.predict_logvar == "global":
+                     # logvar has shape (batch,); other dims are singletons.
+                     N = lv.shape[0]
+                     new_shape = (*mean.shape[:1], *([1] * len(mean.shape[1:])))
+                 else:
+                     raise ValueError(
+                         f"Invalid value for self.predict_logvar: {self.predict_logvar}"
+                     )
+
+                 lv = torch.mean(lv.reshape(N, -1), dim=1)
+                 lv = lv.reshape(new_shape)
+
+             # Optionally, clip the log-variance to a lower bound
+             if self.logvar_lowerbound is not None:
+                 lv = torch.clip(lv, min=self.logvar_lowerbound)
+         else:
+             mean = x
+             lv = None
+         return mean, lv
+
+     def distr_params(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """
+         Get the parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.
+
+         Parameters
+         ----------
+         x: torch.Tensor
+             The input tensor to the likelihood module, i.e., the output of the top-down pass.
+         """
+         mean, lv = self.get_mean_lv(x)
+         params = {
+             "mean": mean,
+             "logvar": lv,
+         }
+         return params
+
+     @staticmethod
+     def mean(params):
+         return params["mean"]
+
+     @staticmethod
+     def mode(params):
+         return params["mean"]
+
+     @staticmethod
+     def sample(params):
+         # p = Normal(params['mean'], (params['logvar'] / 2).exp())
+         # return p.rsample()
+         return params["mean"]
+
+     @staticmethod
+     def logvar(params):
+         return params["logvar"]
+
+     def log_likelihood(self, x, params):
+         if self.predict_logvar is not None:
+             logprob = log_normal(x, params["mean"], params["logvar"])
+         else:
+             logprob = -0.5 * (params["mean"] - x) ** 2
+         return logprob
+
+
+ def log_normal(
+     x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor
+ ) -> torch.Tensor:
+     r"""
+     Compute the log-probability at `x` of a Gaussian distribution
+     with parameters `(mean, exp(logvar))`.
+
+     NOTE: In the case of the LVAE, the log-likelihood formula becomes:
+         \mathbb{E}_{z_1 \sim q_\phi}[\log p_\theta(x|z_1)] =
+             -\frac{1}{2} \left( \mathbb{E}_{z_1 \sim q_\phi}[\log 2\pi\sigma_{p,0}^2(z_1)]
+             + \mathbb{E}_{z_1 \sim q_\phi}\left[ \frac{(x - \mu_{p,0}(z_1))^2}{\sigma_{p,0}^2(z_1)} \right] \right)
+
+     Parameters
+     ----------
+     x: torch.Tensor
+         The ground-truth tensor. Shape is (batch, channels, dim1, dim2).
+     mean: torch.Tensor
+         The inferred mean of the distribution. Shape is (batch, channels, dim1, dim2).
+     logvar: torch.Tensor
+         The inferred log-variance of the distribution. Shape has to be either scalar or broadcastable.
+     """
+     var = torch.exp(logvar)
+     log_prob = -0.5 * (
+         ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
+     )
+     return log_prob
+
+
+ class NoiseModelLikelihood(LikelihoodModule):
+     """A `LikelihoodModule` that evaluates the likelihood through a user-provided noise model."""
+
+     def __init__(
+         self,
+         ch_in: int,
+         color_channels: int,
+         data_mean: Union[Dict[str, torch.Tensor], torch.Tensor],
+         data_std: Union[Dict[str, torch.Tensor], torch.Tensor],
+         noiseModel: nn.Module,
+     ):
+         super().__init__()
+         self.parameter_net = (
+             nn.Identity()
+         )  # nn.Conv2d(ch_in, color_channels, kernel_size=3, padding=1)
+         self.data_mean = data_mean
+         self.data_std = data_std
+         self.noiseModel = noiseModel
+
+     def set_params_to_same_device_as(self, correct_device_tensor):
+         if isinstance(self.data_mean, torch.Tensor):
+             if self.data_mean.device != correct_device_tensor.device:
+                 self.data_mean = self.data_mean.to(correct_device_tensor.device)
+                 self.data_std = self.data_std.to(correct_device_tensor.device)
+         elif isinstance(self.data_mean, dict):
+             for key in self.data_mean.keys():
+                 self.data_mean[key] = self.data_mean[key].to(
+                     correct_device_tensor.device
+                 )
+                 self.data_std[key] = self.data_std[key].to(correct_device_tensor.device)
+
+     def get_mean_lv(self, x):
+         return self.parameter_net(x), None
+
+     def distr_params(self, x):
+         mean, lv = self.get_mean_lv(x)
+         # mean, lv = x.chunk(2, dim=1)
+
+         params = {
+             "mean": mean,
+             "logvar": lv,
+         }
+         return params
+
+     @staticmethod
+     def mean(params):
+         return params["mean"]
+
+     @staticmethod
+     def mode(params):
+         return params["mean"]
+
+     @staticmethod
+     def sample(params):
+         # p = Normal(params['mean'], (params['logvar'] / 2).exp())
+         # return p.rsample()
+         return params["mean"]
+
+     def log_likelihood(self, x: torch.Tensor, params: Dict[str, torch.Tensor]):
+         """
+         Compute the log-likelihood given the parameters `params` obtained from the
+         reconstruction tensor and the target tensor `x`.
+         """
+         predicted_s_denormalized = (
+             params["mean"] * self.data_std["target"] + self.data_mean["target"]
+         )
+         x_denormalized = x * self.data_std["target"] + self.data_mean["target"]
+         likelihoods = self.noiseModel.likelihood(
+             x_denormalized, predicted_s_denormalized
+         )
+         logprob = torch.log(likelihoods)
+         return logprob
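
For orientation, a minimal usage sketch of the new module follows. This is not from the package's documentation; it simply exercises the constructor and `forward` signatures visible in the hunk above, and the shapes and the `torch.distributions.Normal` cross-check are assumptions based on that code. Note that `log_normal` computes log N(x; mu, e^logvar) = -1/2 ((x - mu)^2 / e^logvar + logvar + log 2*pi), the standard Gaussian log-density with sigma = exp(logvar / 2).

import torch
from torch.distributions import Normal

from careamics.models.lvae.likelihoods import GaussianLikelihood, log_normal

# "pixelwise": the input is chunked into (mean, logvar) along dim 1, so the
# channel dimension must be 2 * color_channels.
likelihood = GaussianLikelihood(
    ch_in=64,  # unused in this version: parameter_net is nn.Identity
    color_channels=1,
    predict_logvar="pixelwise",
)
decoder_out = torch.randn(4, 2, 64, 64)  # (batch, 2 * color_channels, H, W)
target = torch.randn(4, 1, 64, 64)

ll, out = likelihood(decoder_out, target)
print(ll.shape)             # torch.Size([4, 1, 64, 64])
print(out["logvar"].shape)  # torch.Size([4, 1, 64, 64])

# "global" averages the logvar over everything but the batch dimension.
lh_global = GaussianLikelihood(ch_in=64, color_channels=1, predict_logvar="global")
_, out = lh_global(decoder_out, target)
print(out["logvar"].shape)  # torch.Size([4, 1, 1, 1])

# log_normal agrees with torch.distributions.Normal when sigma = exp(logvar / 2).
x = torch.tensor([0.0, 1.0, -1.0])
mean, logvar = torch.zeros(3), torch.zeros(3)
reference = Normal(mean, (0.5 * logvar).exp()).log_prob(x)
assert torch.allclose(log_normal(x, mean, logvar), reference, atol=1e-6)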