careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of careamics has been flagged as potentially problematic.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0

careamics/models/lvae/utils.py
@@ -0,0 +1,395 @@
+ """
+ Script for utility functions needed by the LVAE model.
+ """
+
+ from typing import Iterable
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms.functional as F
+ from torch.distributions.normal import Normal
+
+
+ def torch_nanmean(inp):
+     return torch.mean(inp[~inp.isnan()])
+
+
+ def compute_batch_mean(x):
+     N = len(x)
+     return x.view(N, -1).mean(dim=1)
+
+
+ def power_of_2(x):
+     assert isinstance(x, int)
+     if x == 1:
+         return True
+     if x == 0:
+         # happens with validation
+         return False
+     if x % 2 == 1:
+         return False
+     return power_of_2(x // 2)
+
+
+ class Enum:
+     @classmethod
+     def name(cls, enum_type):
+         for key, value in cls.__dict__.items():
+             if enum_type == value:
+                 return key
+
+     @classmethod
+     def contains(cls, enum_type):
+         for key, value in cls.__dict__.items():
+             if enum_type == value:
+                 return True
+         return False
+
+     @classmethod
+     def from_name(cls, enum_type_str):
+         for key, value in cls.__dict__.items():
+             if key == enum_type_str:
+                 return value
+         raise ValueError(f"{cls.__name__}:{enum_type_str} does not exist.")
+
+
+ class LossType(Enum):
+     Elbo = 0
+     ElboWithCritic = 1
+     ElboMixedReconstruction = 2
+     MSE = 3
+     ElboWithNbrConsistency = 4
+     ElboSemiSupMixedReconstruction = 5
+     ElboCL = 6
+     ElboRestrictedReconstruction = 7
+     DenoiSplitMuSplit = 8
+
+
+ class ModelType(Enum):
+     LadderVae = 3
+     LadderVaeTwinDecoder = 4
+     LadderVAECritic = 5
+     # Separate vampprior: two optimizers
+     LadderVaeSepVampprior = 6
+     # one encoder for mixed input, two for separate inputs.
+     LadderVaeSepEncoder = 7
+     LadderVAEMultiTarget = 8
+     LadderVaeSepEncoderSingleOptim = 9
+     UNet = 10
+     BraveNet = 11
+     LadderVaeStitch = 12
+     LadderVaeSemiSupervised = 13
+     LadderVaeStitch2Stage = 14  # Note that previously trained models will have issues,
+     # since earlier LadderVaeStitch2Stage = 13 and LadderVaeSemiSupervised = 14.
+     LadderVaeMixedRecons = 15
+     LadderVaeCL = 16
+     LadderVaeTwoDataSet = (
+         17  # on one subset, apply disentanglement; on the other, apply reconstruction
+     )
+     LadderVaeTwoDatasetMultiBranch = 18
+     LadderVaeTwoDatasetMultiOptim = 19
+     LVaeDeepEncoderIntensityAug = 20
+     AutoRegresiveLadderVAE = 21
+     LadderVAEInterleavedOptimization = 22
+     Denoiser = 23
+     DenoiserSplitter = 24
+     SplitterDenoiser = 25
+     LadderVAERestrictedReconstruction = 26
+     LadderVAETwoDataSetRestRecon = 27
+     LadderVAETwoDataSetFinetuning = 28
+
+
+ def _pad_crop_img(x, size, mode) -> torch.Tensor:
+     """Pads or crops a tensor.
+
+     Pads or crops a tensor of shape (batch, channels, h, w) to a new height
+     and width given by a tuple.
+
+     Args:
+         x (torch.Tensor): Input image
+         size (list or tuple): Desired size (height, width)
+         mode (str): Mode, either 'pad' or 'crop'
+
+     Returns:
+         The padded or cropped tensor
+     """
+     assert x.dim() == 4 and len(size) == 2
+     size = tuple(size)
+     x_size = x.size()[2:4]
+     if mode == "pad":
+         cond = x_size[0] > size[0] or x_size[1] > size[1]
+     elif mode == "crop":
+         cond = x_size[0] < size[0] or x_size[1] < size[1]
+     else:
+         raise ValueError(f"invalid mode '{mode}'")
+     if cond:
+         raise ValueError(f"trying to {mode} from size {x_size} to size {size}")
+     dr, dc = (abs(x_size[0] - size[0]), abs(x_size[1] - size[1]))
+     dr1, dr2 = dr // 2, dr - (dr // 2)
+     dc1, dc2 = dc // 2, dc - (dc // 2)
+     if mode == "pad":
+         return nn.functional.pad(x, [dc1, dc2, dr1, dr2, 0, 0, 0, 0])
+     elif mode == "crop":
+         return x[:, :, dr1 : x_size[0] - dr2, dc1 : x_size[1] - dc2]
+
+
+ def pad_img_tensor(x, size) -> torch.Tensor:
+     """Pads a tensor.
+
+     Pads a tensor of shape (batch, channels, h, w) to a desired height and width.
+
+     Args:
+         x (torch.Tensor): Input image
+         size (list or tuple): Desired size (height, width)
+
+     Returns:
+         The padded tensor
+     """
+     return _pad_crop_img(x, size, "pad")
+
+
+ def crop_img_tensor(x, size) -> torch.Tensor:
+     """Crops a tensor.
+
+     Crops a tensor of shape (batch, channels, h, w) to a desired height and width
+     given by a tuple.
+
+     Args:
+         x (torch.Tensor): Input image
+         size (list or tuple): Desired size (height, width)
+
+     Returns:
+         The cropped tensor
+     """
+     return _pad_crop_img(x, size, "crop")
+
+
+ class StableExponential:
+     """
+     Class that redefines exp() to increase numerical stability.
+     Naturally, the definition of log() must change accordingly.
+     However, it is worth noting that the two operations remain inverses of each
+     other, meaning that x = log(exp(x)) and x = exp(log(x)) always hold.
+
+     Definition:
+         exp(x) = {
+             exp(x) if x <= 0
+             x + 1  if x > 0
+         }
+
+         log(x) = {
+             x          if x <= 0
+             log(1 + x) if x > 0
+         }
+
+     NOTE 1:
+         Within the class everything is done on the tensor given as input to the
+         constructor. Therefore, when exp() is called, the stable exponential of
+         that tensor is computed, and when log() is called, torch.log() of the
+         stable exponential is computed instead.
+
+     NOTE 2:
+         Given the output of exp(), torch.log() and the log() method of the class
+         give identical results.
+     """
+
+     def __init__(self, tensor):
+         self._raw_tensor = tensor
+         posneg_dic = self.posneg_separation(self._raw_tensor)
+         self.pos_f, self.neg_f = posneg_dic["filter"]
+         self.pos_data, self.neg_data = posneg_dic["value"]
+
+     def posneg_separation(self, tensor):
+         pos = tensor > 0
+         pos_tensor = torch.clip(tensor, min=0)
+
+         neg = tensor <= 0
+         neg_tensor = torch.clip(tensor, max=0)
+
+         return {"filter": [pos, neg], "value": [pos_tensor, neg_tensor]}
+
+     def exp(self):
+         return torch.exp(self.neg_data) * self.neg_f + (1 + self.pos_data) * self.pos_f
+
+     def log(self):
+         return self.neg_data * self.neg_f + torch.log(1 + self.pos_data) * self.pos_f
+
+
+ class StableLogVar:
+     """
+     Class that provides a numerically stable implementation of the log-variance.
+     Specifically, it uses the exp() and log() formulas defined in the
+     `StableExponential` class.
+     """
+
+     def __init__(
+         self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         logvar: torch.Tensor
+             The input (true) logvar vector, to be converted into the stable version.
+         enable_stable: bool, optional
+             Whether to compute the stable version of the log-variance. Default is `True`.
+         var_eps: float, optional
+             The minimum value attainable by the variance. Default is `1e-6`.
+         """
+         self._lv = logvar
+         self._enable_stable = enable_stable
+         self._eps = var_eps
+
+     def get(self) -> torch.Tensor:
+         if self._enable_stable is False:
+             return self._lv
+
+         return torch.log(self.get_var())
+
+     def get_var(self) -> torch.Tensor:
+         """Get the variance from the log-variance."""
+         if self._enable_stable is False:
+             return torch.exp(self._lv)
+         return StableExponential(self._lv).exp() + self._eps
+
+     def get_std(self) -> torch.Tensor:
+         return torch.sqrt(self.get_var())
+
+     def centercrop_to_size(self, size: Iterable[int]) -> None:
+         """
+         Center-crop the log-variance tensor to the desired size.
+
+         Parameters
+         ----------
+         size: Iterable[int]
+             The desired size of the log-variance tensor.
+         """
+         if self._lv.shape[-1] == size:
+             return
+
+         diff = self._lv.shape[-1] - size
+         assert diff > 0 and diff % 2 == 0
+         self._lv = F.center_crop(self._lv, (size, size))
+
+
+ class StableMean:
+
+     def __init__(self, mean):
+         self._mean = mean
+
+     def get(self) -> torch.Tensor:
+         return self._mean
+
+     def centercrop_to_size(self, size: Iterable[int]) -> None:
+         """
+         Center-crop the mean tensor to the desired size.
+
+         Parameters
+         ----------
+         size: Iterable[int]
+             The desired size of the mean tensor.
+         """
+         if self._mean.shape[-1] == size:
+             return
+
+         diff = self._mean.shape[-1] - size
+         assert diff > 0 and diff % 2 == 0
+         self._mean = F.center_crop(self._mean, (size, size))
+
+
+ def allow_numpy(func):
+     """
+     All keyword arguments are passed as-is. Positional arguments are checked:
+     if they are numpy arrays, they are converted to torch tensors.
+     """
+
+     def numpy_wrapper(*args, **kwargs):
+         new_args = []
+         for arg in args:
+             if isinstance(arg, np.ndarray):
+                 arg = torch.Tensor(arg)
+             new_args.append(arg)
+         new_args = tuple(new_args)
+
+         output = func(*new_args, **kwargs)
+         return output
+
+     return numpy_wrapper
+
+
+ class Interpolate(nn.Module):
+     """Wrapper for torch.nn.functional.interpolate."""
+
+     def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False):
+         super().__init__()
+         assert (size is None) == (scale is not None)
+         self.size = size
+         self.scale = scale
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         # use torch.nn.functional here: the module-level F is
+         # torchvision.transforms.functional, which has no interpolate()
+         out = nn.functional.interpolate(
+             x,
+             size=self.size,
+             scale_factor=self.scale,
+             mode=self.mode,
+             align_corners=self.align_corners,
+         )
+         return out
+
+
+ def kl_normal_mc(z, p_mulv, q_mulv):
+     """
+     One-sample estimation of the element-wise KL between two diagonal
+     multivariate normal distributions. Any number of dimensions,
+     broadcasting supported (be careful).
+
+     :param z: Sample drawn from the approximate posterior.
+     :param p_mulv: Tuple of (StableMean, StableLogVar) for the prior p.
+     :param q_mulv: Tuple of (StableMean, StableLogVar) for the posterior q.
+     :return: The element-wise KL estimate log q(z) - log p(z).
+     """
+     assert isinstance(p_mulv, tuple)
+     assert isinstance(q_mulv, tuple)
+     p_mu, p_lv = p_mulv
+     q_mu, q_lv = q_mulv
+
+     p_std = p_lv.get_std()
+     q_std = q_lv.get_std()
+
+     p_distrib = Normal(p_mu.get(), p_std)
+     q_distrib = Normal(q_mu.get(), q_std)
+     return q_distrib.log_prob(z) - p_distrib.log_prob(z)
+
+
+ def free_bits_kl(
+     kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
+ ) -> torch.Tensor:
+     """
+     Computes the free-bits version of the KL divergence.
+     It ensures that the KL does not go to zero for any latent dimension.
+     Hence, it encourages more efficient use of the latent variables,
+     leading to better representation learning.
+
+     NOTE:
+         Takes in the KL with shape (batch size, layers) and returns the KL with
+         free bits (for optimization) with shape (layers,), which is the average
+         free-bits KL per layer in the current batch.
+         If batch_average is False (default), the free bits are applied per layer
+         and per batch element. Otherwise, the free bits are still per layer, but
+         are assigned on average to the whole batch. In both cases the batch
+         average is returned, so it is simply a matter of doing mean(clamp(KL))
+         or clamp(mean(KL)).
+
+     Args:
+         kl (torch.Tensor)
+         free_bits (float)
+         batch_average (bool, optional)
+         eps (float, optional)
+
+     Returns:
+         The KL with free bits
+     """
+     assert kl.dim() == 2
+     if free_bits < eps:
+         return kl.mean(0)
+     if batch_average:
+         return kl.mean(0).clamp(min=free_bits)
+     return kl.clamp(min=free_bits).mean(0)
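
As a quick sanity check of the utilities above, here is a minimal sketch (assuming the definitions from this file are in scope) that exercises the StableExponential round-trip identity x = log(exp(x)), the pad/crop helpers, and the two free-bits variants of the KL:

    import torch

    # StableExponential: log() inverts exp() on both sides of zero
    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
    stable = StableExponential(x)
    assert torch.allclose(torch.log(stable.exp()), stable.log())

    # pad_img_tensor / crop_img_tensor round-trip on a (batch, channels, h, w) tensor
    img = torch.rand(2, 3, 8, 8)
    assert torch.equal(crop_img_tensor(pad_img_tensor(img, (12, 12)), (8, 8)), img)

    # free_bits_kl: shape (batch, layers) in, shape (layers,) out
    kl = torch.rand(8, 4)
    per_element = free_bits_kl(kl, free_bits=0.5)                    # mean(clamp(KL))
    per_batch = free_bits_kl(kl, free_bits=0.5, batch_average=True)  # clamp(mean(KL))
    assert per_element.shape == per_batch.shape == (4,)
    assert torch.all(per_element >= per_batch)  # clamp-then-mean dominates mean-then-clamp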

careamics/prediction_utils/__init__.py
@@ -0,0 +1,10 @@
+ """Package to house various prediction utilities."""
+
+ __all__ = [
+     "stitch_prediction",
+     "stitch_prediction_single",
+     "convert_outputs",
+ ]
+
+ from .prediction_outputs import convert_outputs
+ from .stitch_prediction import stitch_prediction, stitch_prediction_single

careamics/prediction_utils/prediction_outputs.py
@@ -0,0 +1,137 @@
+ """Module containing functions to convert prediction outputs to desired form."""
+
+ from typing import Any, List, Literal, Tuple, Union, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..config.tile_information import TileInformation
+ from .stitch_prediction import stitch_prediction
+
+
+ def convert_outputs(
+     predictions: List[Any], tiled: bool
+ ) -> Union[List[NDArray], NDArray]:
+     """
+     Convert the Lightning trainer outputs to the desired form.
+
+     This method allows stitching back together tiled predictions.
+
+     Parameters
+     ----------
+     predictions : list
+         Predictions that are output from `Trainer.predict`.
+     tiled : bool
+         Whether the predictions are tiled.
+
+     Returns
+     -------
+     list of numpy.ndarray or numpy.ndarray
+         List of arrays with the axes SC(Z)YX. If there is only 1 output it will
+         not be in a list.
+     """
+     if len(predictions) == 0:
+         return predictions
+
+     # this layout is to stop mypy complaining
+     if tiled:
+         predictions_comb = combine_batches(predictions, tiled)
+         predictions_output = stitch_prediction(*predictions_comb)
+     else:
+         predictions_output = combine_batches(predictions, tiled)
+
+     return predictions_output
+
+
+ # for mypy
+ @overload
+ def combine_batches(  # numpydoc ignore=GL08
+     predictions: List[Any], tiled: Literal[True]
+ ) -> Tuple[List[NDArray], List[TileInformation]]: ...
+
+
+ # for mypy
+ @overload
+ def combine_batches(  # numpydoc ignore=GL08
+     predictions: List[Any], tiled: Literal[False]
+ ) -> List[NDArray]: ...
+
+
+ # for mypy
+ @overload
+ def combine_batches(  # numpydoc ignore=GL08
+     predictions: List[Any], tiled: Union[bool, Literal[True], Literal[False]]
+ ) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]: ...
+
+
+ def combine_batches(
+     predictions: List[Any], tiled: bool
+ ) -> Union[List[NDArray], Tuple[List[NDArray], List[TileInformation]]]:
+     """
+     If predictions are in batches, they will be combined.
+
+     Parameters
+     ----------
+     predictions : list
+         Predictions that are output from `Trainer.predict`.
+     tiled : bool
+         Whether the predictions are tiled.
+
+     Returns
+     -------
+     (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
+         Combined batches.
+     """
+     if tiled:
+         return _combine_tiled_batches(predictions)
+     else:
+         return _combine_array_batches(predictions)
+
+
+ def _combine_tiled_batches(
+     predictions: List[Tuple[NDArray, List[TileInformation]]]
+ ) -> Tuple[List[NDArray], List[TileInformation]]:
+     """
+     Combine batches from tiled output.
+
+     Parameters
+     ----------
+     predictions : list of (numpy.ndarray, list of TileInformation)
+         Predictions that are output from `Trainer.predict`. For tiled batches,
+         this is a list of tuples. The first element of each tuple is the
+         prediction output of tiles with dimension (B, C, (Z), Y, X), where B is
+         batch size. The second element of each tuple is a list of
+         TileInformation objects of length B.
+
+     Returns
+     -------
+     tuple of (list of numpy.ndarray, list of TileInformation)
+         Combined batches.
+     """
+     # turn list of lists into a single list
+     tile_infos = [
+         tile_info for _, tile_info_list in predictions for tile_info in tile_info_list
+     ]
+     prediction_tiles: List[NDArray] = _combine_array_batches(
+         [preds for preds, _ in predictions]
+     )
+     return prediction_tiles, tile_infos
+
+
+ def _combine_array_batches(predictions: List[NDArray]) -> List[NDArray]:
+     """
+     Combine batches of arrays.
+
+     Parameters
+     ----------
+     predictions : list
+         Prediction arrays that are output from `Trainer.predict`. A list of
+         arrays that have dimensions (B, C, (Z), Y, X), where B is batch size.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         A list of arrays with dimensions (1, C, (Z), Y, X).
+     """
+     prediction_concat: NDArray = np.concatenate(predictions, axis=0)
+     prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
+     return prediction_split
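
To make the batch handling concrete, here is a small sketch of the untiled path (the array shapes are made up): batched (B, C, Y, X) outputs are concatenated along the sample axis and re-split into per-sample (1, C, Y, X) arrays.

    import numpy as np
    from careamics.prediction_utils import convert_outputs

    # two batches of 4 and 2 samples, each a 1-channel 8x8 prediction
    batches = [np.zeros((4, 1, 8, 8)), np.zeros((2, 1, 8, 8))]
    predictions = convert_outputs(batches, tiled=False)

    assert len(predictions) == 6
    assert predictions[0].shape == (1, 1, 8, 8)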

careamics/prediction_utils/stitch_prediction.py
@@ -0,0 +1,103 @@
+ """Prediction utility functions."""
+
+ import builtins
+ from typing import List, Union
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from careamics.config.tile_information import TileInformation
+
+
+ # TODO: why not allow input and output of torch.tensor ?
+ def stitch_prediction(
+     tiles: List[np.ndarray],
+     tile_infos: List[TileInformation],
+ ) -> List[np.ndarray]:
+     """
+     Stitch tiles back together to form full image(s).
+
+     Tiles are of dimensions SC(Z)YX, where C is the number of channels and can
+     be a singleton dimension.
+
+     Parameters
+     ----------
+     tiles : list of numpy.ndarray
+         Cropped tiles and their respective stitching coordinates. Can contain
+         tiles from multiple images.
+     tile_infos : list of TileInformation
+         List of information and coordinates obtained from
+         `dataset.tiled_patching.extract_tiles`.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         Full image(s).
+     """
+     # Find where to split the lists so that only info from one image is contained.
+     # Do this by locating the last tiles of each image.
+     last_tiles = [tile_info.last_tile for tile_info in tile_infos]
+     last_tile_position = np.where(last_tiles)[0]
+     image_slices = [
+         slice(
+             None if i == 0 else last_tile_position[i - 1] + 1, last_tile_position[i] + 1
+         )
+         for i in range(len(last_tile_position))
+     ]
+     image_predictions = []
+     # slice the lists and apply stitch_prediction_single to each in turn.
+     for image_slice in image_slices:
+         image_predictions.append(
+             stitch_prediction_single(tiles[image_slice], tile_infos[image_slice])
+         )
+     return image_predictions
+
+
+ def stitch_prediction_single(
+     tiles: List[NDArray],
+     tile_infos: List[TileInformation],
+ ) -> NDArray:
+     """
+     Stitch tiles back together to form a full image.
+
+     Tiles are of dimensions SC(Z)YX, where C is the number of channels and can
+     be a singleton dimension.
+
+     Parameters
+     ----------
+     tiles : list of numpy.ndarray
+         Cropped tiles and their respective stitching coordinates.
+     tile_infos : list of TileInformation
+         List of information and coordinates obtained from
+         `dataset.tiled_patching.extract_tiles`.
+
+     Returns
+     -------
+     numpy.ndarray
+         Full image, with dimensions SC(Z)YX.
+     """
+     # retrieve whole array size
+     input_shape = tile_infos[0].array_shape
+     predicted_image = np.zeros(input_shape, dtype=np.float32)
+
+     # reshape
+     # TODO: can be more elegantly solved if TileInformation allows singleton dims
+     singleton_dims = tuple(np.where(np.array(tiles[0].shape) == 1)[0])
+     predicted_image = np.expand_dims(predicted_image, singleton_dims)
+
+     for tile, tile_info in zip(tiles, tile_infos):
+
+         # Compute coordinates for cropping the predicted tile
+         crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
+             ...,
+             *[slice(c[0], c[1]) for c in tile_info.overlap_crop_coords],
+         )
+
+         # Crop the predicted tile according to the overlap coordinates
+         cropped_tile = tile[crop_slices]
+
+         # Insert the cropped tile into the predicted image using stitch coordinates
+         image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
+         predicted_image[image_slices] = cropped_tile.astype(np.float32)
+
+     return predicted_image
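
The stitching contract can be sketched as follows; this assumes TileInformation can be constructed directly with the fields used above (in the package it is produced by the tiling code, e.g. extract_tiles), and that array_shape holds the spatial shape with singleton S/C dimensions recovered from the tile shape:

    import numpy as np
    from careamics.config.tile_information import TileInformation
    from careamics.prediction_utils import stitch_prediction_single

    # four non-overlapping 4x4 tiles covering a single 8x8 image
    tiles, infos = [], []
    for i, (y, x) in enumerate([(0, 0), (0, 4), (4, 0), (4, 4)]):
        tiles.append(np.full((1, 1, 4, 4), i, dtype=np.float32))  # SCYX tile
        infos.append(
            TileInformation(
                array_shape=(8, 8),
                last_tile=(i == 3),
                overlap_crop_coords=((0, 4), (0, 4)),  # keep the whole tile
                stitch_coords=((y, y + 4), (x, x + 4)),
            )
        )

    image = stitch_prediction_single(tiles, infos)
    assert image.shape == (1, 1, 8, 8)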

careamics/transforms/n2v_manipulate.py
@@ -60,7 +60,7 @@ class N2VManipulate(Transform):
          remove_center: bool = True,
          struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
          struct_mask_span: int = 5,
-         seed: Optional[int] = None,  # TODO use in pixel manipulation
+         seed: Optional[int] = None,
      ):
          """Constructor.

@@ -127,6 +127,7 @@ class N2VManipulate(Transform):
                  subpatch_size=self.roi_size,
                  remove_center=self.remove_center,
                  struct_params=self.struct_mask,
+                 rng=self.rng,
              )
          elif self.strategy == SupportedPixelManipulation.MEDIAN:
              # Iterate over the channels to apply manipulation separately
@@ -136,6 +137,7 @@ class N2VManipulate(Transform):
                  mask_pixel_percentage=self.masked_pixel_percentage,
                  subpatch_size=self.roi_size,
                  struct_params=self.struct_mask,
+                 rng=self.rng,
              )
          else:
              raise ValueError(f"Unknown masking strategy ({self.strategy}).")
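
The two hunks above thread the transform's random generator into both pixel-manipulation strategies, so the previously unused seed argument now makes the masking reproducible. A minimal sketch, assuming N2VManipulate is importable from careamics.transforms, is called directly on a CYX patch, and returns a tuple of arrays (all other constructor arguments left at their defaults):

    import numpy as np
    from careamics.transforms import N2VManipulate

    patch = np.random.rand(1, 64, 64).astype(np.float32)
    out_a = N2VManipulate(seed=42)(patch)
    out_b = N2VManipulate(seed=42)(patch)

    # identical seeds should now produce identical masked outputs
    assert all(np.array_equal(a, b) for a, b in zip(out_a, out_b))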