careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/unet.py (new file)
@@ -0,0 +1,443 @@
+ """
+ UNet model.
+
+ A UNet encoder, decoder and complete model.
+ """
+
+ from typing import Any, List, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ..config.support import SupportedActivation
+ from .activation import get_activation
+ from .layers import Conv_Block, MaxBlurPool
+
+
+ class UnetEncoder(nn.Module):
+     """
+     Unet encoder pathway.
+
+     Parameters
+     ----------
+     conv_dim : int
+         Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+     in_channels : int, optional
+         Number of input channels, by default 1.
+     depth : int, optional
+         Number of encoder blocks, by default 3.
+     num_channels_init : int, optional
+         Number of channels in the first encoder block, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     pool_kernel : int, optional
+         Kernel size for the max pooling layers, by default 2.
+     n2v2 : bool, optional
+         Whether to use the N2V2 architecture, by default False.
+     groups : int, optional
+         Number of blocked connections from input channels to output
+         channels, by default 1.
+     """
+
+     def __init__(
+         self,
+         conv_dim: int,
+         in_channels: int = 1,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         pool_kernel: int = 2,
+         n2v2: bool = False,
+         groups: int = 1,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dim : int
+             Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+         in_channels : int, optional
+             Number of input channels, by default 1.
+         depth : int, optional
+             Number of encoder blocks, by default 3.
+         num_channels_init : int, optional
+             Number of channels in the first encoder block, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         pool_kernel : int, optional
+             Kernel size for the max pooling layers, by default 2.
+         n2v2 : bool, optional
+             Whether to use the N2V2 architecture, by default False.
+         groups : int, optional
+             Number of blocked connections from input channels to output
+             channels, by default 1.
+         """
+         super().__init__()
+
+         self.pooling = (
+             getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
+             if not n2v2
+             else MaxBlurPool(dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel)
+         )
+
+         encoder_blocks = []
+
+         for n in range(depth):
+             out_channels = num_channels_init * (2**n) * groups
+             in_channels = in_channels if n == 0 else out_channels // 2
+             encoder_blocks.append(
+                 Conv_Block(
+                     conv_dim,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     dropout_perc=dropout,
+                     use_batch_norm=use_batch_norm,
+                     groups=groups,
+                 )
+             )
+             encoder_blocks.append(self.pooling)
+         self.encoder_blocks = nn.ModuleList(encoder_blocks)
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor.
+
+         Returns
+         -------
+         List[torch.Tensor]
+             Final output of the encoder followed by the output of each encoder
+             block (skip connections).
+         """
+         encoder_features = []
+         for module in self.encoder_blocks:
+             x = module(x)
+             if isinstance(module, Conv_Block):
+                 encoder_features.append(x)
+         features = [x, *encoder_features]
+         return features
+
+
+ class UnetDecoder(nn.Module):
+     """
+     Unet decoder pathway.
+
+     Parameters
+     ----------
+     conv_dim : int
+         Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+     depth : int, optional
+         Number of decoder blocks, by default 3.
+     num_channels_init : int, optional
+         Number of channels in the first encoder block, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     n2v2 : bool, optional
+         Whether to use the N2V2 architecture, by default False.
+     groups : int, optional
+         Number of blocked connections from input channels to output
+         channels, by default 1.
+     """
+
+     def __init__(
+         self,
+         conv_dim: int,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         n2v2: bool = False,
+         groups: int = 1,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dim : int
+             Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+         depth : int, optional
+             Number of decoder blocks, by default 3.
+         num_channels_init : int, optional
+             Number of channels in the first encoder block, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         n2v2 : bool, optional
+             Whether to use the N2V2 architecture, by default False.
+         groups : int, optional
+             Number of blocked connections from input channels to output
+             channels, by default 1.
+         """
+         super().__init__()
+
+         upsampling = nn.Upsample(
+             scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
+         )
+         in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
+
+         self.n2v2 = n2v2
+         self.groups = groups
+
+         self.bottleneck = Conv_Block(
+             conv_dim,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             intermediate_channel_multiplier=2,
+             use_batch_norm=use_batch_norm,
+             dropout_perc=dropout,
+             groups=self.groups,
+         )
+
+         decoder_blocks: List[nn.Module] = []
+         for n in range(depth):
+             decoder_blocks.append(upsampling)
+             in_channels = (num_channels_init * 2 ** (depth - n)) * groups
+             out_channels = in_channels // 2
+             decoder_blocks.append(
+                 Conv_Block(
+                     conv_dim,
+                     in_channels=(
+                         in_channels + in_channels // 2 if n > 0 else in_channels
+                     ),
+                     out_channels=out_channels,
+                     intermediate_channel_multiplier=2,
+                     dropout_perc=dropout,
+                     activation="ReLU",
+                     use_batch_norm=use_batch_norm,
+                     groups=groups,
+                 )
+             )
+
+         self.decoder_blocks = nn.ModuleList(decoder_blocks)
+
+     def forward(self, *features: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         *features : torch.Tensor
+             Output of each encoder block (skip connections) and the final output
+             of the encoder.
+
+         Returns
+         -------
+         torch.Tensor
+             Output of the decoder.
+         """
+         x: torch.Tensor = features[0]
+         skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
+
+         x = self.bottleneck(x)
+
+         for i, module in enumerate(self.decoder_blocks):
+             x = module(x)
+             if isinstance(module, nn.Upsample):
+                 # divide index by 2 because of upsampling layers
+                 skip_connection: torch.Tensor = skip_connections[i // 2]
+                 if self.n2v2:
+                     if x.shape != skip_connections[-1].shape:
+                         x = self._interleave(x, skip_connection, self.groups)
+                 else:
+                     x = self._interleave(x, skip_connection, self.groups)
+         return x
+
+     @staticmethod
+     def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
+         """Interleave two tensors.
+
+         Splits the tensors `A` and `B` into equally sized groups along the channel
+         axis (axis=1); then concatenates the groups in alternating order along the
+         channel axis, starting with the first group from tensor A.
+
+         Parameters
+         ----------
+         A : torch.Tensor
+             First tensor.
+         B : torch.Tensor
+             Second tensor.
+         groups : int
+             The number of groups.
+
+         Returns
+         -------
+         torch.Tensor
+             Interleaved tensor.
+
+         Raises
+         ------
+         ValueError
+             If the channel axis of either `A` or `B` is not divisible by `groups`.
+         """
+         if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
+             raise ValueError(f"Number of channels not divisible by {groups} groups.")
+
+         m = A.shape[1] // groups
+         n = B.shape[1] // groups
+
+         A_groups: List[torch.Tensor] = [
+             A[:, i * m : (i + 1) * m] for i in range(groups)
+         ]
+         B_groups: List[torch.Tensor] = [
+             B[:, i * n : (i + 1) * n] for i in range(groups)
+         ]
+
+         interleaved = torch.cat(
+             [
+                 tensor_list[i]
+                 for i in range(groups)
+                 for tensor_list in [A_groups, B_groups]
+             ],
+             dim=1,
+         )
+
+         return interleaved
+
+
+ class UNet(nn.Module):
+     """
+     UNet model.
+
+     Adapted for PyTorch from:
+     https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.
+
+     Parameters
+     ----------
+     conv_dims : int
+         Number of dimensions of the convolution layers (2 or 3).
+     num_classes : int, optional
+         Number of classes to predict, by default 1.
+     in_channels : int, optional
+         Number of input channels, by default 1.
+     depth : int, optional
+         Number of downsamplings, by default 3.
+     num_channels_init : int, optional
+         Number of filters in the first convolution layer, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     pool_kernel : int, optional
+         Kernel size of the pooling layers, by default 2.
+     final_activation : Union[SupportedActivation, str], optional
+         Activation function to use for the last layer, by default
+         SupportedActivation.NONE.
+     n2v2 : bool, optional
+         Whether to use the N2V2 architecture, by default False.
+     independent_channels : bool, optional
+         Whether to train the channels independently, by default True.
+     **kwargs : Any
+         Additional keyword arguments, unused.
+     """
+
+     def __init__(
+         self,
+         conv_dims: int,
+         num_classes: int = 1,
+         in_channels: int = 1,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         pool_kernel: int = 2,
+         final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
+         n2v2: bool = False,
+         independent_channels: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dims : int
+             Number of dimensions of the convolution layers (2 or 3).
+         num_classes : int, optional
+             Number of classes to predict, by default 1.
+         in_channels : int, optional
+             Number of input channels, by default 1.
+         depth : int, optional
+             Number of downsamplings, by default 3.
+         num_channels_init : int, optional
+             Number of filters in the first convolution layer, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         pool_kernel : int, optional
+             Kernel size of the pooling layers, by default 2.
+         final_activation : Union[SupportedActivation, str], optional
+             Activation function to use for the last layer, by default
+             SupportedActivation.NONE.
+         n2v2 : bool, optional
+             Whether to use the N2V2 architecture, by default False.
+         independent_channels : bool, optional
+             Whether to train parallel independent networks for each channel, by
+             default True.
+         **kwargs : Any
+             Additional keyword arguments, unused.
+         """
+         super().__init__()
+
+         groups = in_channels if independent_channels else 1
+
+         self.encoder = UnetEncoder(
+             conv_dims,
+             in_channels=in_channels,
+             depth=depth,
+             num_channels_init=num_channels_init,
+             use_batch_norm=use_batch_norm,
+             dropout=dropout,
+             pool_kernel=pool_kernel,
+             n2v2=n2v2,
+             groups=groups,
+         )
+
+         self.decoder = UnetDecoder(
+             conv_dims,
+             depth=depth,
+             num_channels_init=num_channels_init,
+             use_batch_norm=use_batch_norm,
+             dropout=dropout,
+             n2v2=n2v2,
+             groups=groups,
+         )
+         self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
+             in_channels=num_channels_init * groups,
+             out_channels=num_classes,
+             kernel_size=1,
+             groups=groups,
+         )
+         self.final_activation = get_activation(final_activation)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor.
+
+         Returns
+         -------
+         torch.Tensor
+             Output of the model.
+         """
+         encoder_features = self.encoder(x)
+         x = self.decoder(*encoder_features)
+         x = self.final_conv(x)
+         x = self.final_activation(x)
+         return x
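
For orientation, a minimal usage sketch of the UNet added above (assuming it is importable as careamics.models.unet, per the file list; the input's spatial sides must be divisible by 2**depth so the skip connections match the upsampled feature maps):

import torch

from careamics.models.unet import UNet

# 2D UNet with the defaults shown above (depth=3, num_channels_init=64);
# a 64 x 64 input is divisible by 2**3, so encoder/decoder shapes line up.
model = UNet(conv_dims=2, in_channels=1, num_classes=1)
x = torch.rand(1, 1, 64, 64)  # (batch, channels, Y, X)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 64, 64])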
careamics/prediction_utils/__init__.py (new file)
@@ -0,0 +1,10 @@
+ """Package to house various prediction utilities."""
+
+ __all__ = [
+     "stitch_prediction",
+     "stitch_prediction_single",
+     "convert_outputs",
+ ]
+
+ from .prediction_outputs import convert_outputs
+ from .stitch_prediction import stitch_prediction, stitch_prediction_single
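
Given the re-exports above, the helpers can be imported directly from the subpackage, e.g. `from careamics.prediction_utils import convert_outputs, stitch_prediction`.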
careamics/prediction_utils/lvae_prediction.py (new file)
@@ -0,0 +1,158 @@
+ """Module containing pytorch implementations for obtaining predictions from an LVAE."""
+
+ from typing import Any, Optional
+
+ import torch
+
+ from careamics.models.lvae import LadderVAE as LVAE
+ from careamics.models.lvae.likelihoods import LikelihoodModule
+
+ # TODO: convert these functions to lightning module `predict_step`
+ # -> mmse_count will have to be an instance attribute?
+
+
+ # This function is needed because the output of the datasets (input here) can include
+ # auxiliary items, such as the TileInformation. This function allows for easier reuse
+ # between lvae_predict_single_sample and lvae_predict_mmse.
+ def lvae_predict_single_sample(
+     model: LVAE,
+     likelihood_obj: LikelihoodModule,
+     input: torch.Tensor,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+     """
+     Generate a single sample prediction from an LVAE model, for a given input.
+
+     Parameters
+     ----------
+     model : LVAE
+         Trained LVAE model.
+     likelihood_obj : LikelihoodModule
+         Instance of a likelihood class.
+     input : torch.Tensor
+         Input to generate prediction for. Expected shape is (S, C, Y, X).
+
+     Returns
+     -------
+     tuple of (torch.Tensor, optional torch.Tensor)
+         The first element is the sample prediction, and the second element is the
+         log-variance. The log-variance will be None if `model.predict_logvar` is
+         None.
+     """
+     model.eval()  # Not in original predict code: affects batch norm and dropout layers
+     with torch.no_grad():
+         output: torch.Tensor
+         output, _ = model(input)  # 2nd item is top-down data dict
+
+     # presently, get_mean_lv just splits the output in 2 if predict_logvar=True,
+     # optionally clips the logvar if logvar_lowerbound is not None
+     # TODO: consider refactoring to remove use of the likelihood object
+     sample_prediction, log_var = likelihood_obj.get_mean_lv(output)
+
+     # TODO: output denormalization using target stats that will be saved in data config
+     # -> Don't think we need this, saw it in a random bit of code somewhere.
+
+     return sample_prediction, log_var
+
+
+ def lvae_predict_tiled_batch(
+     model: LVAE,
+     likelihood_obj: LikelihoodModule,
+     input: tuple[Any, ...],
+ ) -> tuple[tuple[Any, ...], Optional[tuple[Any, ...]]]:
+     # TODO: fix docstring return types, ... too many output options
+     """
+     Generate a single sample prediction from an LVAE model, for a given tiled batch.
+
+     Parameters
+     ----------
+     model : LVAE
+         Trained LVAE model.
+     likelihood_obj : LikelihoodModule
+         Instance of a likelihood class.
+     input : torch.Tensor | tuple of (torch.Tensor, Any, ...)
+         Input to generate prediction for. This can include auxiliary inputs such as
+         `TileInformation`, but the model input is always the first item of the tuple.
+         Expected shape of the model input is (S, C, Y, X).
+
+     Returns
+     -------
+     tuple of ((torch.Tensor, Any, ...), optional tuple of (torch.Tensor, Any, ...))
+         The first element is the sample prediction, and the second element is the
+         log-variance. The log-variance will be None if `model.predict_logvar` is
+         None. Any auxiliary data included in the input will also be included with
+         both the sample prediction and the log-variance.
+     """
+     x: torch.Tensor
+     aux: list[Any]
+     x, *aux = input
+
+     sample_prediction, log_var = lvae_predict_single_sample(
+         model=model, likelihood_obj=likelihood_obj, input=x
+     )
+
+     log_var_output = (log_var, *aux) if log_var is not None else None
+     return (sample_prediction, *aux), log_var_output
+
+
+ def lvae_predict_mmse_tiled_batch(
+     model: LVAE,
+     likelihood_obj: LikelihoodModule,
+     input: tuple[Any, ...],
+     mmse_count: int,
+ ) -> tuple[tuple[Any, ...], tuple[Any, ...], Optional[tuple[Any, ...]]]:
+     # TODO: fix docstring return types, ... hard to make readable
+     """
+     Generate the MMSE (minimum mean squared error) prediction, for a given input.
+
+     This is calculated from the mean of multiple single sample predictions.
+
+     Parameters
+     ----------
+     model : LVAE
+         Trained LVAE model.
+     likelihood_obj : LikelihoodModule
+         Instance of a likelihood class.
+     input : torch.Tensor | tuple of (torch.Tensor, Any, ...)
+         Input to generate prediction for. This can include auxiliary inputs such as
+         `TileInformation`, but the model input is always the first item of the tuple.
+         Expected shape of the model input is (S, C, Y, X).
+     mmse_count : int
+         Number of samples to generate to calculate MMSE (minimum mean squared error).
+
+     Returns
+     -------
+     tuple of (tuple of (torch.Tensor, Any, ...))
+         A tuple of 3 elements. The first element contains the MMSE prediction, the
+         second contains the standard deviation of the samples used to create the
+         MMSE prediction. Finally, the last element contains the log-variance of the
+         likelihood; this will be `None` if `likelihood.predict_logvar` is `None`.
+         Any auxiliary data included in the input will also be included with the MMSE
+         prediction, the standard deviation, and the log-variance.
+     """
+     if mmse_count <= 0:
+         raise ValueError("MMSE count must be greater than zero.")
+
+     x: torch.Tensor
+     aux: list[Any]
+     x, *aux = input
+
+     input_shape = x.shape
+     output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
+     log_var: Optional[torch.Tensor] = None
+     # pre-declare empty array to fill with individual sample predictions
+     sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
+     for mmse_idx in range(mmse_count):
+         sample_prediction, lv = lvae_predict_single_sample(
+             model=model, likelihood_obj=likelihood_obj, input=x
+         )
+         # only keep the log variance of the first sample prediction
+         if mmse_idx == 0:
+             log_var = lv
+
+         # store sample predictions
+         sample_predictions[mmse_idx, ...] = sample_prediction
+
+     mmse_prediction = torch.mean(sample_predictions, dim=0)
+     mmse_prediction_std = torch.std(sample_predictions, dim=0)
+
+     log_var_output = (log_var, *aux) if log_var is not None else None
+     return (mmse_prediction, *aux), (mmse_prediction_std, *aux), log_var_output
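
The MMSE logic above does not depend on LVAE internals: it is a mean/std over repeated stochastic forward passes. A minimal runnable sketch of that pattern, with a hypothetical noisy_predict standing in for the model/likelihood pair:

import torch

def noisy_predict(x: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for one stochastic LVAE sample prediction
    return x + 0.1 * torch.randn_like(x)

x = torch.rand(1, 1, 8, 8)  # (S, C, Y, X)
mmse_count = 16
# stack individual samples, then reduce over the sample axis
samples = torch.stack([noisy_predict(x) for _ in range(mmse_count)], dim=0)
mmse_prediction = samples.mean(dim=0)     # MMSE estimate
mmse_prediction_std = samples.std(dim=0)  # per-pixel spread of the samples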