careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/models/lvae/utils.py ADDED
@@ -0,0 +1,395 @@
+ """
+ Script for utility functions needed by the LVAE model.
+ """
+
+ from typing import Iterable
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms.functional as F
+ from torch.distributions.normal import Normal
+
+
+ def torch_nanmean(inp):
+     return torch.mean(inp[~inp.isnan()])
+
+
+ def compute_batch_mean(x):
+     N = len(x)
+     return x.view(N, -1).mean(dim=1)
+
+
+ def power_of_2(self, x):
+     assert isinstance(x, int)
+     if x == 1:
+         return True
+     if x == 0:
+         # happens with validation
+         return False
+     if x % 2 == 1:
+         return False
+     return self.power_of_2(x // 2)
+
+
+ class Enum:
+     @classmethod
+     def name(cls, enum_type):
+         for key, value in cls.__dict__.items():
+             if enum_type == value:
+                 return key
+
+     @classmethod
+     def contains(cls, enum_type):
+         for key, value in cls.__dict__.items():
+             if enum_type == value:
+                 return True
+         return False
+
+     @classmethod
+     def from_name(cls, enum_type_str):
+         for key, value in cls.__dict__.items():
+             if key == enum_type_str:
+                 return value
+         assert f"{cls.__name__}:{enum_type_str} does not exist."
+
+
+ class LossType(Enum):
+     Elbo = 0
+     ElboWithCritic = 1
+     ElboMixedReconstruction = 2
+     MSE = 3
+     ElboWithNbrConsistency = 4
+     ElboSemiSupMixedReconstruction = 5
+     ElboCL = 6
+     ElboRestrictedReconstruction = 7
+     DenoiSplitMuSplit = 8
+
+
+ class ModelType(Enum):
+     LadderVae = 3
+     LadderVaeTwinDecoder = 4
+     LadderVAECritic = 5
+     # Separate vampprior: two optimizers
+     LadderVaeSepVampprior = 6
+     # one encoder for mixed input, two for separate inputs.
+     LadderVaeSepEncoder = 7
+     LadderVAEMultiTarget = 8
+     LadderVaeSepEncoderSingleOptim = 9
+     UNet = 10
+     BraveNet = 11
+     LadderVaeStitch = 12
+     LadderVaeSemiSupervised = 13
+     LadderVaeStitch2Stage = 14  # Note that previously trained models will have issues,
+     # since earlier LadderVaeStitch2Stage = 13 and LadderVaeSemiSupervised = 14.
+     LadderVaeMixedRecons = 15
+     LadderVaeCL = 16
+     LadderVaeTwoDataSet = (
+         17  # on one subset, apply disentanglement; on the other, apply reconstruction
+     )
+     LadderVaeTwoDatasetMultiBranch = 18
+     LadderVaeTwoDatasetMultiOptim = 19
+     LVaeDeepEncoderIntensityAug = 20
+     AutoRegresiveLadderVAE = 21
+     LadderVAEInterleavedOptimization = 22
+     Denoiser = 23
+     DenoiserSplitter = 24
+     SplitterDenoiser = 25
+     LadderVAERestrictedReconstruction = 26
+     LadderVAETwoDataSetRestRecon = 27
+     LadderVAETwoDataSetFinetuning = 28
+
+
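For context, a minimal usage sketch of the Enum helper and the enums above (editor's illustration, not part of the package diff; it assumes the installed wheel exposes careamics.models.lvae.utils):

from careamics.models.lvae.utils import LossType, ModelType

# Look up members by value or by name with the lightweight Enum helper.
print(LossType.name(3))            # "MSE"
print(LossType.from_name("Elbo"))  # 0
print(ModelType.contains(10))      # True (UNet = 10)
print(ModelType.contains(99))      # False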
+ def _pad_crop_img(x, size, mode) -> torch.Tensor:
+     """Pads or crops a tensor.
+     Pads or crops a tensor of shape (batch, channels, h, w) to new height
+     and width given by a tuple.
+     Args:
+         x (torch.Tensor): Input image
+         size (list or tuple): Desired size (height, width)
+         mode (str): Mode, either 'pad' or 'crop'
+     Returns:
+         The padded or cropped tensor
+     """
+     assert x.dim() == 4 and len(size) == 2
+     size = tuple(size)
+     x_size = x.size()[2:4]
+     if mode == "pad":
+         cond = x_size[0] > size[0] or x_size[1] > size[1]
+     elif mode == "crop":
+         cond = x_size[0] < size[0] or x_size[1] < size[1]
+     else:
+         raise ValueError(f"invalid mode '{mode}'")
+     if cond:
+         raise ValueError(f"trying to {mode} from size {x_size} to size {size}")
+     dr, dc = (abs(x_size[0] - size[0]), abs(x_size[1] - size[1]))
+     dr1, dr2 = dr // 2, dr - (dr // 2)
+     dc1, dc2 = dc // 2, dc - (dc // 2)
+     if mode == "pad":
+         return nn.functional.pad(x, [dc1, dc2, dr1, dr2, 0, 0, 0, 0])
+     elif mode == "crop":
+         return x[:, :, dr1 : x_size[0] - dr2, dc1 : x_size[1] - dc2]
+
+
+ def pad_img_tensor(x, size) -> torch.Tensor:
+     """Pads a tensor.
+     Pads a tensor of shape (batch, channels, h, w) to a desired height and width.
+     Args:
+         x (torch.Tensor): Input image
+         size (list or tuple): Desired size (height, width)
+
+     Returns
+     -------
+     The padded tensor
+     """
+     return _pad_crop_img(x, size, "pad")
+
+
+ def crop_img_tensor(x, size) -> torch.Tensor:
+     """Crops a tensor.
+     Crops a tensor of shape (batch, channels, h, w) to a desired height and width
+     given by a tuple.
+     Args:
+         x (torch.Tensor): Input image
+         size (list or tuple): Desired size (height, width)
+
+     Returns
+     -------
+     The cropped tensor
+     """
+     return _pad_crop_img(x, size, "crop")
+
+
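As a quick round-trip illustration of the padding/cropping helpers above (editor's sketch with hypothetical sizes, not part of the diff; assumes careamics.models.lvae.utils is importable):

import torch
from careamics.models.lvae.utils import crop_img_tensor, pad_img_tensor

x = torch.randn(1, 1, 28, 28)
padded = pad_img_tensor(x, (32, 32))          # zero-pads 2 rows/columns on each side
cropped = crop_img_tensor(padded, (28, 28))   # center-crops back to the original size
print(padded.shape, cropped.shape)            # torch.Size([1, 1, 32, 32]) torch.Size([1, 1, 28, 28])
assert torch.equal(cropped, x)                # pad followed by crop is lossless here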
+ class StableExponential:
+     """
+     Class that redefines the definition of exp() to increase numerical stability.
+     Naturally, the definition of log() must change accordingly as well.
+     However, it is worth noting that each operation remains the inverse of the other,
+     meaning that x = log(exp(x)) and x = exp(log(x)) are always true.
+
+     Definition:
+         exp(x) = {
+             exp(x)  if x<=0
+             x+1     if x>0
+         }
+
+         log(x) = {
+             x         if x<=0
+             log(1+x)  if x>0
+         }
+
+     NOTE 1:
+         Within the class everything is done on the tensor given as input to the constructor.
+         Therefore, when exp() is called, self._tensor.exp() is computed.
+         When log() is called, torch.log(self._tensor.exp()) is computed instead.
+
+     NOTE 2:
+         Given the output from exp(), torch.log() or the log() method of the class give identical results.
+     """
+
+     def __init__(self, tensor):
+         self._raw_tensor = tensor
+         posneg_dic = self.posneg_separation(self._raw_tensor)
+         self.pos_f, self.neg_f = posneg_dic["filter"]
+         self.pos_data, self.neg_data = posneg_dic["value"]
+
+     def posneg_separation(self, tensor):
+         pos = tensor > 0
+         pos_tensor = torch.clip(tensor, min=0)
+
+         neg = tensor <= 0
+         neg_tensor = torch.clip(tensor, max=0)
+
+         return {"filter": [pos, neg], "value": [pos_tensor, neg_tensor]}
+
+     def exp(self):
+         return torch.exp(self.neg_data) * self.neg_f + (1 + self.pos_data) * self.pos_f
+
+     def log(self):
+         return self.neg_data * self.neg_f + torch.log(1 + self.pos_data) * self.pos_f
+
+
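A short numerical check of the piecewise exp()/log() pair documented above (editor's sketch, not part of the diff; assumes careamics.models.lvae.utils is importable):

import torch
from careamics.models.lvae.utils import StableExponential

t = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
stable = StableExponential(t)

# For x <= 0 the usual exponential is used; for x > 0 it only grows linearly as x + 1.
print(stable.exp())  # tensor([0.1353, 0.6065, 1.0000, 1.5000, 3.0000])

# As stated in NOTE 2, torch.log of the stable exp matches the class's own log().
assert torch.allclose(torch.log(stable.exp()), stable.log())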
+ class StableLogVar:
+     """
+     Class that provides a numerically stable implementation of Log-Variance.
+     Specifically, it uses the exp() and log() formulas defined in the `StableExponential` class.
+     """
+
+     def __init__(
+         self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
+     ):
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         logvar: torch.Tensor
+             The input (true) logvar vector, to be converted into the stable version.
+         enable_stable: bool, optional
+             Whether to compute the stable version of log-variance. Default is `True`.
+         var_eps: float, optional
+             The minimum value attainable by the variance. Default is `1e-6`.
+         """
+         self._lv = logvar
+         self._enable_stable = enable_stable
+         self._eps = var_eps
+
+     def get(self) -> torch.Tensor:
+         if self._enable_stable is False:
+             return self._lv
+
+         return torch.log(self.get_var())
+
+     def get_var(self) -> torch.Tensor:
+         """
+         Get Variance from Log-Variance.
+         """
+         if self._enable_stable is False:
+             return torch.exp(self._lv)
+         return StableExponential(self._lv).exp() + self._eps
+
+     def get_std(self) -> torch.Tensor:
+         return torch.sqrt(self.get_var())
+
+     def centercrop_to_size(self, size: Iterable[int]) -> None:
+         """
+         Centercrop the log-variance tensor to the desired size.
+
+         Parameters
+         ----------
+         size: torch.Tensor
+             The desired size of the log-variance tensor.
+         """
+         if self._lv.shape[-1] == size:
+             return
+
+         diff = self._lv.shape[-1] - size
+         assert diff > 0 and diff % 2 == 0
+         self._lv = F.center_crop(self._lv, (size, size))
+
+
+ class StableMean:
+
+     def __init__(self, mean):
+         self._mean = mean
+
+     def get(self) -> torch.Tensor:
+         return self._mean
+
+     def centercrop_to_size(self, size: Iterable[int]) -> None:
+         """
+         Centercrop the mean tensor to the desired size.
+
+         Parameters
+         ----------
+         size: torch.Tensor
+             The desired size of the mean tensor.
+         """
+         if self._mean.shape[-1] == size:
+             return
+
+         diff = self._mean.shape[-1] - size
+         assert diff > 0 and diff % 2 == 0
+         self._mean = F.center_crop(self._mean, (size, size))
+
+
+ def allow_numpy(func):
+     """
+     All keyword arguments are passed as is. Positional arguments are checked and,
+     if they are numpy arrays, converted to torch Tensors.
+     """
+
+     def numpy_wrapper(*args, **kwargs):
+         new_args = []
+         for arg in args:
+             if isinstance(arg, np.ndarray):
+                 arg = torch.Tensor(arg)
+             new_args.append(arg)
+         new_args = tuple(new_args)
+
+         output = func(*new_args, **kwargs)
+         return output
+
+     return numpy_wrapper
+
+
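A tiny usage sketch of the allow_numpy decorator above (the decorated function is hypothetical and for illustration only, not part of the diff):

import numpy as np
import torch
from careamics.models.lvae.utils import allow_numpy

@allow_numpy
def mean_intensity(img: torch.Tensor) -> torch.Tensor:
    # Receives a torch.Tensor even when the caller passes a numpy array.
    return img.mean()

print(mean_intensity(np.arange(4.0)))     # tensor(1.5000)
print(mean_intensity(torch.arange(4.0)))  # tensor(1.5000)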
+ class Interpolate(nn.Module):
+     """Wrapper for torch.nn.functional.interpolate."""
+
+     def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False):
+         super().__init__()
+         assert (size is None) == (scale is not None)
+         self.size = size
+         self.scale = scale
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         out = F.interpolate(
+             x,
+             size=self.size,
+             scale_factor=self.scale,
+             mode=self.mode,
+             align_corners=self.align_corners,
+         )
+         return out
+
+
+ def kl_normal_mc(z, p_mulv, q_mulv):
+     """
+     One-sample estimation of element-wise KL between two diagonal
+     multivariate normal distributions. Any number of dimensions,
+     broadcasting supported (be careful).
+     :param z:
+     :param p_mulv:
+     :param q_mulv:
+     :return:
+     """
+     assert isinstance(p_mulv, tuple)
+     assert isinstance(q_mulv, tuple)
+     p_mu, p_lv = p_mulv
+     q_mu, q_lv = q_mulv
+
+     p_std = p_lv.get_std()
+     q_std = q_lv.get_std()
+
+     p_distrib = Normal(p_mu.get(), p_std)
+     q_distrib = Normal(q_mu.get(), q_std)
+     return q_distrib.log_prob(z) - p_distrib.log_prob(z)
+
+
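Since the docstring above leaves its parameters undocumented, here is a hedged sketch of how kl_normal_mc appears intended to be called, with (mean, log-variance) tuples wrapped in the Stable* classes; the comparison against the closed-form KL is the editor's illustration, not part of the diff:

import torch
from torch.distributions import Normal, kl_divergence
from careamics.models.lvae.utils import StableLogVar, StableMean, kl_normal_mc

torch.manual_seed(0)
n = 100_000
q_mu, q_lv = StableMean(torch.zeros(n)), StableLogVar(torch.zeros(n))
p_mu, p_lv = StableMean(torch.full((n,), 0.5)), StableLogVar(torch.zeros(n))

# One sample per element, drawn from q, gives a Monte-Carlo estimate of KL(q || p).
z = Normal(q_mu.get(), q_lv.get_std()).sample()
mc_kl = kl_normal_mc(z, (p_mu, p_lv), (q_mu, q_lv)).mean()

exact = kl_divergence(
    Normal(q_mu.get(), q_lv.get_std()), Normal(p_mu.get(), p_lv.get_std())
).mean()
print(mc_kl, exact)  # both close to 0.125 for these parameters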
+ def free_bits_kl(
+     kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
+ ) -> torch.Tensor:
+     """
+     Computes free-bits version of KL divergence.
+     Ensures that the KL doesn't go to zero for any latent dimension.
+     Hence, it encourages more efficient use of the latent variables,
+     leading to better representation learning.
+
+     NOTE:
+         Takes in the KL with shape (batch size, layers), returns the KL with
+         free bits (for optimization) with shape (layers,), which is the average
+         free-bits KL per layer in the current batch.
+         If batch_average is False (default), the free bits are per layer and
+         per batch element. Otherwise, the free bits are still per layer, but
+         are assigned on average to the whole batch. In both cases, the batch
+         average is returned, so it's simply a matter of doing mean(clamp(KL))
+         or clamp(mean(KL)).
+
+     Args:
+         kl (torch.Tensor)
+         free_bits (float)
+         batch_average (bool, optional)
+         eps (float, optional)
+
+     Returns
+     -------
+     The KL with free bits
+     """
+     assert kl.dim() == 2
+     if free_bits < eps:
+         return kl.mean(0)
+     if batch_average:
+         return kl.mean(0).clamp(min=free_bits)
+     return kl.clamp(min=free_bits).mean(0)
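To make the mean(clamp(KL)) versus clamp(mean(KL)) distinction above concrete, a small numeric sketch (editor's illustration, not part of the diff; assumes careamics.models.lvae.utils is importable):

import torch
from careamics.models.lvae.utils import free_bits_kl

# KL per (batch element, layer): batch of 2, 2 layers.
kl = torch.tensor([[0.0, 2.0],
                   [4.0, 2.0]])

# Default: clamp each entry to >= free_bits, then average over the batch.
print(free_bits_kl(kl, free_bits=1.0))                      # tensor([2.5000, 2.0000])

# batch_average=True: average over the batch first, then clamp each layer.
print(free_bits_kl(kl, free_bits=1.0, batch_average=True))  # tensor([2., 2.])

# With free_bits below eps, the plain batch-averaged KL is returned unchanged.
print(free_bits_kl(kl, free_bits=0.0))                      # tensor([2., 2.])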
careamics/models/model_factory.py CHANGED
@@ -27,7 +27,7 @@ def model_factory(
   Parameters
   ----------
   model_configuration : Union[UNetModel, VAEModel]
- Model configuration
+ Model configuration.
 
   Returns
   -------
careamics/models/unet.py CHANGED
@@ -34,7 +34,9 @@ class UnetEncoder(nn.Module):
   Dropout probability, by default 0.0.
   pool_kernel : int, optional
   Kernel size for the max pooling layers, by default 2.
- groups: int, optional
+ n2v2 : bool, optional
+ Whether to use N2V2 architecture, by default False.
+ groups : int, optional
   Number of blocked connections from input channels to output
   channels, by default 1.
   """
@@ -70,7 +72,9 @@ class UnetEncoder(nn.Module):
   Dropout probability, by default 0.0.
   pool_kernel : int, optional
   Kernel size for the max pooling layers, by default 2.
- groups: int, optional
+ n2v2 : bool, optional
+ Whether to use N2V2 architecture, by default False.
+ groups : int, optional
   Number of blocked connections from input channels to output
   channels, by default 1.
   """
@@ -140,7 +144,9 @@ class UnetDecoder(nn.Module):
   Whether to use batch normalization, by default True.
   dropout : float, optional
   Dropout probability, by default 0.0.
- groups: int, optional
+ n2v2 : bool, optional
+ Whether to use N2V2 architecture, by default False.
+ groups : int, optional
   Number of blocked connections from input channels to output
   channels, by default 1.
   """
@@ -170,7 +176,9 @@ class UnetDecoder(nn.Module):
   Whether to use batch normalization, by default True.
   dropout : float, optional
   Dropout probability, by default 0.0.
- groups: int, optional
+ n2v2 : bool, optional
+ Whether to use N2V2 architecture, by default False.
+ groups : int, optional
   Number of blocked connections from input channels to output
   channels, by default 1.
   """
@@ -250,22 +258,25 @@ class UnetDecoder(nn.Module):
 
   @staticmethod
   def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
- """
- Splits the tensors `A` and `B` into equally sized groups along the
- channel axis (axis=1); then concatenates the groups in alternating
- order along the channel axis, starting with the first group from tensor
- A.
+ """Interleave two tensors.
+
+ Splits the tensors `A` and `B` into equally sized groups along the channel
+ axis (axis=1); then concatenates the groups in alternating order along the
+ channel axis, starting with the first group from tensor A.
 
   Parameters
   ----------
- A: torch.Tensor
- B: torch.Tensor
- groups: int
+ A : torch.Tensor
+ First tensor.
+ B : torch.Tensor
+ Second tensor.
+ groups : int
   The number of groups.
 
   Returns
   -------
   torch.Tensor
+ Interleaved tensor.
 
   Raises
   ------
@@ -322,8 +333,14 @@ class UNet(nn.Module):
   Dropout probability, by default 0.0.
   pool_kernel : int, optional
   Kernel size of the pooling layers, by default 2.
- last_activation : Optional[Callable], optional
+ final_activation : Optional[Callable], optional
   Activation function to use for the last layer, by default None.
+ n2v2 : bool, optional
+ Whether to use N2V2 architecture, by default False.
+ independent_channels : bool
+ Whether to train the channels independently, by default True.
+ **kwargs : Any
+ Additional keyword arguments, unused.
   """
 
   def __init__(
@@ -362,11 +379,15 @@ class UNet(nn.Module):
   Dropout probability, by default 0.0.
   pool_kernel : int, optional
   Kernel size of the pooling layers, by default 2.
- last_activation : Optional[Callable], optional
+ final_activation : Optional[Callable], optional
   Activation function to use for the last layer, by default None.
+ n2v2 : bool, optional
+ Whether to use N2V2 architecture, by default False.
   independent_channels : bool
   Whether to train parallel independent networks for each channel, by
   default True.
+ **kwargs : Any
+ Additional keyword arguments, unused.
   """
   super().__init__()
 
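The reworked _interleave docstring above describes splitting A and B into equally sized channel groups and concatenating them in alternating order. A standalone sketch of that documented behavior (editor's illustration; the helper name interleave_channels is hypothetical and this is not the package's implementation):

import torch

def interleave_channels(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
    # Split each tensor into `groups` equal chunks along the channel axis (dim=1),
    # then concatenate the chunks in alternating order: A1, B1, A2, B2, ...
    a_chunks = torch.chunk(A, groups, dim=1)
    b_chunks = torch.chunk(B, groups, dim=1)
    interleaved = [chunk for pair in zip(a_chunks, b_chunks) for chunk in pair]
    return torch.cat(interleaved, dim=1)

A = torch.zeros(1, 4, 8, 8)  # four channels of zeros
B = torch.ones(1, 4, 8, 8)   # four channels of ones
out = interleave_channels(A, B, groups=2)
print(out[0, :, 0, 0])       # tensor([0., 0., 1., 1., 0., 0., 1., 1.])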
careamics/prediction_utils/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """Package to house various prediction utilities."""
+
+ __all__ = [
+     "create_pred_datamodule",
+     "stitch_prediction",
+     "stitch_prediction_single",
+     "convert_outputs",
+ ]
+
+ from .create_pred_datamodule import create_pred_datamodule
+ from .prediction_outputs import convert_outputs
+ from .stitch_prediction import stitch_prediction, stitch_prediction_single