careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of careamics might be problematic.

Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/models/unet.py
@@ -0,0 +1,443 @@
+ """
+ UNet model.
+
+ A UNet encoder, decoder and complete model.
+ """
+
+ from typing import Any, List, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ..config.support import SupportedActivation
+ from .activation import get_activation
+ from .layers import Conv_Block, MaxBlurPool
+
+
+ class UnetEncoder(nn.Module):
+     """
+     Unet encoder pathway.
+
+     Parameters
+     ----------
+     conv_dim : int
+         Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+     in_channels : int, optional
+         Number of input channels, by default 1.
+     depth : int, optional
+         Number of encoder blocks, by default 3.
+     num_channels_init : int, optional
+         Number of channels in the first encoder block, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     pool_kernel : int, optional
+         Kernel size for the max pooling layers, by default 2.
+     n2v2 : bool, optional
+         Whether to use the N2V2 architecture, by default False.
+     groups : int, optional
+         Number of blocked connections from input channels to output
+         channels, by default 1.
+     """
+
+     def __init__(
+         self,
+         conv_dim: int,
+         in_channels: int = 1,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         pool_kernel: int = 2,
+         n2v2: bool = False,
+         groups: int = 1,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dim : int
+             Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+         in_channels : int, optional
+             Number of input channels, by default 1.
+         depth : int, optional
+             Number of encoder blocks, by default 3.
+         num_channels_init : int, optional
+             Number of channels in the first encoder block, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         pool_kernel : int, optional
+             Kernel size for the max pooling layers, by default 2.
+         n2v2 : bool, optional
+             Whether to use the N2V2 architecture, by default False.
+         groups : int, optional
+             Number of blocked connections from input channels to output
+             channels, by default 1.
+         """
+         super().__init__()
+
+         self.pooling = (
+             getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
+             if not n2v2
+             else MaxBlurPool(dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel)
+         )
+
+         encoder_blocks = []
+
+         for n in range(depth):
+             out_channels = num_channels_init * (2**n) * groups
+             in_channels = in_channels if n == 0 else out_channels // 2
+             encoder_blocks.append(
+                 Conv_Block(
+                     conv_dim,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     dropout_perc=dropout,
+                     use_batch_norm=use_batch_norm,
+                     groups=groups,
+                 )
+             )
+             encoder_blocks.append(self.pooling)
+         self.encoder_blocks = nn.ModuleList(encoder_blocks)
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor.
+
+         Returns
+         -------
+         List[torch.Tensor]
+             Final output of the encoder followed by the output of each encoder
+             block (skip connections).
+         """
+         encoder_features = []
+         for module in self.encoder_blocks:
+             x = module(x)
+             if isinstance(module, Conv_Block):
+                 encoder_features.append(x)
+         features = [x, *encoder_features]
+         return features
+
+
+ class UnetDecoder(nn.Module):
+     """
+     Unet decoder pathway.
+
+     Parameters
+     ----------
+     conv_dim : int
+         Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+     depth : int, optional
+         Number of decoder blocks, by default 3.
+     num_channels_init : int, optional
+         Number of channels in the first encoder block, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     n2v2 : bool, optional
+         Whether to use the N2V2 architecture, by default False.
+     groups : int, optional
+         Number of blocked connections from input channels to output
+         channels, by default 1.
+     """
+
+     def __init__(
+         self,
+         conv_dim: int,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         n2v2: bool = False,
+         groups: int = 1,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dim : int
+             Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+         depth : int, optional
+             Number of decoder blocks, by default 3.
+         num_channels_init : int, optional
+             Number of channels in the first encoder block, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         n2v2 : bool, optional
+             Whether to use the N2V2 architecture, by default False.
+         groups : int, optional
+             Number of blocked connections from input channels to output
+             channels, by default 1.
+         """
+         super().__init__()
+
+         upsampling = nn.Upsample(
+             scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
+         )
+         in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
+
+         self.n2v2 = n2v2
+         self.groups = groups
+
+         self.bottleneck = Conv_Block(
+             conv_dim,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             intermediate_channel_multiplier=2,
+             use_batch_norm=use_batch_norm,
+             dropout_perc=dropout,
+             groups=self.groups,
+         )
+
+         decoder_blocks: List[nn.Module] = []
+         for n in range(depth):
+             decoder_blocks.append(upsampling)
+             in_channels = (num_channels_init * 2 ** (depth - n)) * groups
+             out_channels = in_channels // 2
+             decoder_blocks.append(
+                 Conv_Block(
+                     conv_dim,
+                     in_channels=(
+                         in_channels + in_channels // 2 if n > 0 else in_channels
+                     ),
+                     out_channels=out_channels,
+                     intermediate_channel_multiplier=2,
+                     dropout_perc=dropout,
+                     activation="ReLU",
+                     use_batch_norm=use_batch_norm,
+                     groups=groups,
+                 )
+             )
+
+         self.decoder_blocks = nn.ModuleList(decoder_blocks)
+
+     def forward(self, *features: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         *features : torch.Tensor
+             Final output of the encoder followed by the output of each encoder
+             block (skip connections).
+
+         Returns
+         -------
+         torch.Tensor
+             Output of the decoder.
+         """
+         x: torch.Tensor = features[0]
+         skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
+
+         x = self.bottleneck(x)
+
+         for i, module in enumerate(self.decoder_blocks):
+             x = module(x)
+             if isinstance(module, nn.Upsample):
+                 # divide index by 2 because of the interleaved upsampling layers
+                 skip_connection: torch.Tensor = skip_connections[i // 2]
+                 if self.n2v2:
+                     if x.shape != skip_connections[-1].shape:
+                         x = self._interleave(x, skip_connection, self.groups)
+                 else:
+                     x = self._interleave(x, skip_connection, self.groups)
+         return x
+
+     @staticmethod
+     def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
+         """Interleave two tensors.
+
+         Splits the tensors `A` and `B` into equally sized groups along the channel
+         axis (axis=1); then concatenates the groups in alternating order along the
+         channel axis, starting with the first group from tensor A.
+
+         Parameters
+         ----------
+         A : torch.Tensor
+             First tensor.
+         B : torch.Tensor
+             Second tensor.
+         groups : int
+             The number of groups.
+
+         Returns
+         -------
+         torch.Tensor
+             Interleaved tensor.
+
+         Raises
+         ------
+         ValueError
+             If the channel axis of either `A` or `B` is not divisible by `groups`.
+         """
+         if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
+             raise ValueError(f"Number of channels not divisible by {groups} groups.")
+
+         m = A.shape[1] // groups
+         n = B.shape[1] // groups
+
+         A_groups: List[torch.Tensor] = [
+             A[:, i * m : (i + 1) * m] for i in range(groups)
+         ]
+         B_groups: List[torch.Tensor] = [
+             B[:, i * n : (i + 1) * n] for i in range(groups)
+         ]
+
+         interleaved = torch.cat(
+             [
+                 tensor_list[i]
+                 for i in range(groups)
+                 for tensor_list in [A_groups, B_groups]
+             ],
+             dim=1,
+         )
+
+         return interleaved
+
+
+ class UNet(nn.Module):
+     """
+     UNet model.
+
+     Adapted for PyTorch from:
+     https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.
+
+     Parameters
+     ----------
+     conv_dims : int
+         Number of dimensions of the convolution layers (2 or 3).
+     num_classes : int, optional
+         Number of classes to predict, by default 1.
+     in_channels : int, optional
+         Number of input channels, by default 1.
+     depth : int, optional
+         Number of downsamplings, by default 3.
+     num_channels_init : int, optional
+         Number of filters in the first convolution layer, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     pool_kernel : int, optional
+         Kernel size of the pooling layers, by default 2.
+     final_activation : Union[SupportedActivation, str], optional
+         Activation applied after the final convolution, by default
+         SupportedActivation.NONE.
+     n2v2 : bool, optional
+         Whether to use the N2V2 architecture, by default False.
+     independent_channels : bool
+         Whether to train the channels independently, by default True.
+     **kwargs : Any
+         Additional keyword arguments, unused.
+     """
+
+     def __init__(
+         self,
+         conv_dims: int,
+         num_classes: int = 1,
+         in_channels: int = 1,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         pool_kernel: int = 2,
+         final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
+         n2v2: bool = False,
+         independent_channels: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dims : int
+             Number of dimensions of the convolution layers (2 or 3).
+         num_classes : int, optional
+             Number of classes to predict, by default 1.
+         in_channels : int, optional
+             Number of input channels, by default 1.
+         depth : int, optional
+             Number of downsamplings, by default 3.
+         num_channels_init : int, optional
+             Number of filters in the first convolution layer, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         pool_kernel : int, optional
+             Kernel size of the pooling layers, by default 2.
+         final_activation : Union[SupportedActivation, str], optional
+             Activation applied after the final convolution, by default
+             SupportedActivation.NONE.
+         n2v2 : bool, optional
+             Whether to use the N2V2 architecture, by default False.
+         independent_channels : bool
+             Whether to train parallel independent networks for each channel, by
+             default True.
+         **kwargs : Any
+             Additional keyword arguments, unused.
+         """
+         super().__init__()
+
+         groups = in_channels if independent_channels else 1
+
+         self.encoder = UnetEncoder(
+             conv_dims,
+             in_channels=in_channels,
+             depth=depth,
+             num_channels_init=num_channels_init,
+             use_batch_norm=use_batch_norm,
+             dropout=dropout,
+             pool_kernel=pool_kernel,
+             n2v2=n2v2,
+             groups=groups,
+         )
+
+         self.decoder = UnetDecoder(
+             conv_dims,
+             depth=depth,
+             num_channels_init=num_channels_init,
+             use_batch_norm=use_batch_norm,
+             dropout=dropout,
+             n2v2=n2v2,
+             groups=groups,
+         )
+         self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
+             in_channels=num_channels_init * groups,
+             out_channels=num_classes,
+             kernel_size=1,
+             groups=groups,
+         )
+         self.final_activation = get_activation(final_activation)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor.
+
+         Returns
+         -------
+         torch.Tensor
+             Output of the model.
+         """
+         encoder_features = self.encoder(x)
+         x = self.decoder(*encoder_features)
+         x = self.final_conv(x)
+         x = self.final_activation(x)
+         return x
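
For orientation, here is a minimal usage sketch of the new model. The hyperparameters and tensor shapes below are illustrative assumptions, not values shipped with the release; with independent_channels=True, groups equals in_channels, so each channel is processed by its own parallel sub-network and skip connections are merged group-wise via _interleave.

    import torch
    from careamics.models.unet import UNet

    # Hypothetical configuration: a 2D N2V2-style network on a two-channel
    # image, with each channel handled independently (groups = in_channels).
    model = UNet(
        conv_dims=2,
        num_classes=2,
        in_channels=2,
        depth=2,
        num_channels_init=32,
        n2v2=True,
        independent_channels=True,
    )
    x = torch.rand(4, 2, 64, 64)  # (batch, channels, Y, X)
    y = model(x)  # output has num_classes channels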
careamics/prediction_utils/__init__.py
@@ -0,0 +1,10 @@
+ """Package to house various prediction utilities."""
+
+ __all__ = [
+     "stitch_prediction",
+     "stitch_prediction_single",
+     "convert_outputs",
+ ]
+
+ from .prediction_outputs import convert_outputs
+ from .stitch_prediction import stitch_prediction, stitch_prediction_single
careamics/prediction_utils/prediction_outputs.py
@@ -0,0 +1,135 @@
+ """Module containing functions to convert prediction outputs to the desired form."""
+
+ from typing import Any, List, Literal, Tuple, Union, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..config.tile_information import TileInformation
+ from .stitch_prediction import stitch_prediction
+
+
+ def convert_outputs(predictions: List[Any], tiled: bool) -> list[NDArray]:
+     """
+     Convert the Lightning trainer outputs to the desired form.
+
+     This method allows stitching back together tiled predictions.
+
+     Parameters
+     ----------
+     predictions : list
+         Predictions that are output from `Trainer.predict`.
+     tiled : bool
+         Whether the predictions are tiled.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         List of arrays with the axes SC(Z)YX.
+     """
+     if len(predictions) == 0:
+         return predictions
+
+     # this layout is to stop mypy complaining
+     if tiled:
+         predictions_comb = combine_batches(predictions, tiled)
+         predictions_output = stitch_prediction(*predictions_comb)
+     else:
+         predictions_output = combine_batches(predictions, tiled)
+
+     return predictions_output
+
+
+ # for mypy
+ @overload
+ def combine_batches(  # numpydoc ignore=GL08
+     predictions: List[Any], tiled: Literal[True]
+ ) -> Tuple[List[NDArray], List[TileInformation]]: ...
+
+
+ # for mypy
+ @overload
+ def combine_batches(  # numpydoc ignore=GL08
+     predictions: List[Any], tiled: Literal[False]
+ ) -> List[NDArray]: ...
+
+
+ # for mypy
+ @overload
+ def combine_batches(  # numpydoc ignore=GL08
+     predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
+ ) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...
+
+
+ def combine_batches(
+     predictions: List[Any], tiled: bool
+ ) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
+     """
+     Combine predictions that are grouped in batches.
+
+     Parameters
+     ----------
+     predictions : list
+         Predictions that are output from `Trainer.predict`.
+     tiled : bool
+         Whether the predictions are tiled.
+
+     Returns
+     -------
+     (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
+         Combined batches.
+     """
+     if tiled:
+         return _combine_tiled_batches(predictions)
+     else:
+         return _combine_array_batches(predictions)
+
+
+ def _combine_tiled_batches(
+     predictions: List[Tuple[NDArray, List[TileInformation]]]
+ ) -> Tuple[List[NDArray], List[TileInformation]]:
+     """
+     Combine batches from tiled output.
+
+     Parameters
+     ----------
+     predictions : list of (numpy.ndarray, list of TileInformation)
+         Predictions that are output from `Trainer.predict`. For tiled batches, this is
+         a list of tuples. The first element of each tuple is the prediction output of
+         tiles with dimensions (B, C, (Z), Y, X), where B is the batch size. The second
+         element is a list of TileInformation objects of length B.
+
+     Returns
+     -------
+     tuple of (list of numpy.ndarray, list of TileInformation)
+         Combined batches.
+     """
+     # turn list of lists into a single list
+     tile_infos = [
+         tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
+     ]
+     prediction_tiles: List[NDArray] = _combine_array_batches(
+         [preds for preds, _ in predictions]
+     )
+     return prediction_tiles, tile_infos
+
+
+ def _combine_array_batches(predictions: List[NDArray]) -> List[NDArray]:
+     """
+     Combine batches of arrays.
+
+     Parameters
+     ----------
+     predictions : list
+         Prediction arrays that are output from `Trainer.predict`. A list of arrays
+         that have dimensions (B, C, (Z), Y, X), where B is the batch size.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         A list of arrays with dimensions (1, C, (Z), Y, X).
+     """
+     prediction_concat: NDArray = np.concatenate(predictions, axis=0)
+     prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
+     return prediction_split
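
To illustrate the untiled path with a hypothetical example (the array sizes are made up): batches from `Trainer.predict` are concatenated along the batch axis and re-split into per-sample arrays with a singleton batch dimension.

    import numpy as np
    from careamics.prediction_utils import convert_outputs

    # Two hypothetical batches of size 2, as returned by Trainer.predict.
    batch_a = np.zeros((2, 1, 8, 8))  # (B, C, Y, X)
    batch_b = np.ones((2, 1, 8, 8))

    samples = convert_outputs([batch_a, batch_b], tiled=False)
    assert len(samples) == 4  # one array per sample
    assert samples[0].shape == (1, 1, 8, 8)  # singleton batch axis is kept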
careamics/prediction_utils/stitch_prediction.py
@@ -0,0 +1,98 @@
+ """Prediction utility functions."""
+
+ import builtins
+ from typing import List, Union
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from careamics.config.tile_information import TileInformation
+
+
+ # TODO: why not allow input and output of torch.Tensor?
+ def stitch_prediction(
+     tiles: List[np.ndarray],
+     tile_infos: List[TileInformation],
+ ) -> List[np.ndarray]:
+     """
+     Stitch tiles back together to form the full image(s).
+
+     Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+     singleton dimension.
+
+     Parameters
+     ----------
+     tiles : list of numpy.ndarray
+         Cropped tiles. Can contain tiles from multiple images.
+     tile_infos : list of TileInformation
+         List of information and coordinates obtained from
+         `dataset.tiled_patching.extract_tiles`.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         Full image(s).
+     """
+     # Find where to split the lists so that only info from one image is contained.
+     # Do this by locating the last tile of each image.
+     last_tiles = [tile_info.last_tile for tile_info in tile_infos]
+     last_tile_position = np.where(last_tiles)[0]
+     image_slices = [
+         slice(
+             None if i == 0 else last_tile_position[i - 1] + 1, last_tile_position[i] + 1
+         )
+         for i in range(len(last_tile_position))
+     ]
+     image_predictions = []
+     # Slice the lists and apply stitch_prediction_single to each in turn.
+     for image_slice in image_slices:
+         image_predictions.append(
+             stitch_prediction_single(tiles[image_slice], tile_infos[image_slice])
+         )
+     return image_predictions
+
+
+ def stitch_prediction_single(
+     tiles: List[NDArray],
+     tile_infos: List[TileInformation],
+ ) -> NDArray:
+     """
+     Stitch tiles back together to form a full image.
+
+     Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+     singleton dimension.
+
+     Parameters
+     ----------
+     tiles : list of numpy.ndarray
+         Cropped tiles.
+     tile_infos : list of TileInformation
+         List of information and coordinates obtained from
+         `dataset.tiled_patching.extract_tiles`.
+
+     Returns
+     -------
+     numpy.ndarray
+         Full image, with dimensions SC(Z)YX.
+     """
+     # Retrieve the whole array size
+     input_shape = (1, *tile_infos[0].array_shape)  # add S dim
+     predicted_image = np.zeros(input_shape, dtype=np.float32)
+
+     for tile, tile_info in zip(tiles, tile_infos):
+
+         # Compute coordinates for cropping the predicted tile
+         crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
+             ...,
+             *[slice(c[0], c[1]) for c in tile_info.overlap_crop_coords],
+         )
+
+         # Crop the predicted tile according to the overlap coordinates
+         cropped_tile = tile[crop_slices]
+
+         # Insert the cropped tile into the predicted image using stitch coordinates
+         image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
+         predicted_image[image_slices] = cropped_tile.astype(np.float32)
+
+     return predicted_image
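
To make the crop-and-paste step concrete, here is a self-contained numpy sketch of the same slicing logic with made-up coordinates; in practice the overlap_crop_coords and stitch_coords come from the TileInformation objects produced by extract_tiles.

    import numpy as np

    # Reassemble a (1, 1, 4, 4) image from two (1, 1, 4, 3) tiles whose
    # coverage overlaps by two columns; each tile keeps only its own half.
    predicted_image = np.zeros((1, 1, 4, 4), dtype=np.float32)

    # (tile, overlap_crop_coords, stitch_coords) triples, illustrative only.
    tiles = [
        (np.full((1, 1, 4, 3), 1.0), ((0, 4), (0, 2)), ((0, 4), (0, 2))),  # left
        (np.full((1, 1, 4, 3), 2.0), ((0, 4), (1, 3)), ((0, 4), (2, 4))),  # right
    ]
    for tile, crop_coords, stitch_coords in tiles:
        crop_slices = (..., *[slice(c[0], c[1]) for c in crop_coords])
        image_slices = (..., *[slice(c[0], c[1]) for c in stitch_coords])
        predicted_image[image_slices] = tile[crop_slices]

    assert predicted_image[0, 0, 0, 1] == 1.0  # left half from the first tile
    assert predicted_image[0, 0, 0, 2] == 2.0  # right half from the second tile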
careamics/transforms/__init__.py
@@ -0,0 +1,20 @@
+ """Transforms that are used to augment the data."""
+
+ __all__ = [
+     "get_all_transforms",
+     "N2VManipulate",
+     "XYFlip",
+     "XYRandomRotate90",
+     "ImageRestorationTTA",
+     "Denormalize",
+     "Normalize",
+     "Compose",
+ ]
+
+
+ from .compose import Compose, get_all_transforms
+ from .n2v_manipulate import N2VManipulate
+ from .normalize import Denormalize, Normalize
+ from .tta import ImageRestorationTTA
+ from .xy_flip import XYFlip
+ from .xy_random_rotate90 import XYRandomRotate90