careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

This version of careamics has been flagged as potentially problematic.

Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/models/layers.py CHANGED
@@ -3,6 +3,7 @@ Layer module.
 
 This submodule contains layers used in the CAREamics models.
 """
+
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -161,6 +162,18 @@ def _unpack_kernel_size(
     """Unpack kernel_size to a tuple of ints.
 
     Inspired by Kornia implementation. TODO: link
+
+    Parameters
+    ----------
+    kernel_size : Union[Tuple[int, ...], int]
+        Kernel size.
+    dim : int
+        Number of dimensions.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Kernel size tuple.
     """
     if isinstance(kernel_size, int):
         kernel_dims = tuple([kernel_size for _ in range(dim)])
@@ -172,7 +185,20 @@ def _unpack_kernel_size(
 def _compute_zero_padding(
     kernel_size: Union[Tuple[int, ...], int], dim: int
 ) -> Tuple[int, ...]:
-    """Utility function that computes zero padding tuple."""
+    """Utility function that computes zero padding tuple.
+
+    Parameters
+    ----------
+    kernel_size : Union[Tuple[int, ...], int]
+        Kernel size.
+    dim : int
+        Number of dimensions.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Zero padding tuple.
+    """
     kernel_dims = _unpack_kernel_size(kernel_size, dim)
     return tuple([(kd - 1) // 2 for kd in kernel_dims])
 
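The `(kd - 1) // 2` formula above is the standard "same"-style padding for odd kernel sizes. A minimal standalone sketch of the two helpers (plain Python mirroring the diff, not importing the private functions):

    from typing import Tuple, Union

    def unpack_kernel_size(kernel_size: Union[Tuple[int, ...], int], dim: int) -> Tuple[int, ...]:
        # broadcast an int to one entry per spatial dimension
        if isinstance(kernel_size, int):
            return tuple(kernel_size for _ in range(dim))
        return tuple(kernel_size)

    def compute_zero_padding(kernel_size: Union[Tuple[int, ...], int], dim: int) -> Tuple[int, ...]:
        # (kd - 1) // 2 preserves spatial size for odd kernels at stride 1
        return tuple((kd - 1) // 2 for kd in unpack_kernel_size(kernel_size, dim))

    assert compute_zero_padding(3, dim=2) == (1, 1)
    assert compute_zero_padding((5, 3), dim=2) == (2, 1)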
@@ -190,14 +216,19 @@ def get_pascal_kernel_1d(
 
     Parameters
     ----------
-    kernel_size: height and width of the kernel.
-    norm: if to normalize the kernel or not. Default: False.
-    device: tensor device
-    dtype: tensor dtype
+    kernel_size : int
+        Kernel size.
+    norm : bool
+        Normalize the kernel, by default False.
+    device : Optional[torch.device]
+        Device of the tensor, by default None.
+    dtype : Optional[torch.dtype]
+        Data type of the tensor, by default None.
 
     Returns
     -------
-    kernel shaped as :math:`(kernel_size,)`
+    torch.Tensor
+        Pascal kernel.
 
     Examples
     --------
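The "Pascal kernel" named here is a row of Pascal's triangle (binomial coefficients), so norm=True yields a binomial blur filter. A minimal sketch of the idea, independent of the package internals:

    import torch

    def pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor:
        # build a row of Pascal's triangle: [1], [1, 1], [1, 2, 1], ...
        row = [1.0]
        for _ in range(kernel_size - 1):
            row = [a + b for a, b in zip([0.0] + row, row + [0.0])]
        kernel = torch.tensor(row)
        return kernel / kernel.sum() if norm else kernel

    print(pascal_kernel_1d(5))             # tensor([1., 4., 6., 4., 1.])
    print(pascal_kernel_1d(3, norm=True))  # tensor([0.2500, 0.5000, 0.2500])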
@@ -244,19 +275,28 @@ def _get_pascal_kernel_nd(
 ) -> torch.Tensor:
     """Generate pascal filter kernel by kernel size.
 
+    If kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
+    otherwise the kernel will be shaped as kernel_size
+
     Inspired by Kornia implementation.
 
     Parameters
     ----------
-    kernel_size: height and width of the kernel.
-    norm: if to normalize the kernel or not. Default: True.
-    device: tensor device
-    dtype: tensor dtype
+    kernel_size : Union[Tuple[int, int], int]
+        Kernel size for the pascal kernel.
+    norm : bool
+        Normalize the kernel, by default True.
+    dim : int
+        Number of dimensions, by default 2.
+    device : Optional[torch.device]
+        Device of the tensor, by default None.
+    dtype : Optional[torch.dtype]
+        Data type of the tensor, by default None.
 
     Returns
     -------
-    if kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
-    otherwise the kernel will be shaped as kernel_size
+    torch.Tensor
+        Pascal kernel.
 
     Examples
     --------
@@ -302,6 +342,24 @@ def _max_blur_pool_by_kernel2d(
     """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.
 
     Inspired by Kornia implementation.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor.
+    kernel : torch.Tensor
+        Kernel tensor.
+    stride : int
+        Stride.
+    max_pool_size : int
+        Maximum pool size.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor.
     """
     # compute local maxima
     x = F.max_pool2d(
@@ -322,6 +380,24 @@ def _max_blur_pool_by_kernel3d(
     """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.
 
     Inspired by Kornia implementation.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor.
+    kernel : torch.Tensor
+        Kernel tensor.
+    stride : int
+        Stride.
+    max_pool_size : int
+        Maximum pool size.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor.
     """
     # compute local maxima
     x = F.max_pool3d(
@@ -342,21 +418,16 @@ class MaxBlurPool(nn.Module):
 
     Parameters
     ----------
-    dim: int
-        Toggles between 2D and 3D
-    kernel_size: Union[Tuple[int, int], int]
+    dim : int
+        Toggles between 2D and 3D.
+    kernel_size : Union[Tuple[int, int], int]
         Kernel size for max pooling.
-    stride: int
+    stride : int
         Stride for pooling.
-    max_pool_size: int
+    max_pool_size : int
         Max kernel size for max pooling.
-    ceil_mode: bool
-        Should be true to match output size of conv2d with same kernel size.
-
-    Returns
-    -------
-    torch.Tensor
-        The pooled and blurred tensor.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
     """
 
     def __init__(
@@ -367,6 +438,21 @@ class MaxBlurPool(nn.Module):
         max_pool_size: int = 2,
         ceil_mode: bool = False,
     ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        dim : int
+            Dimension of the convolution.
+        kernel_size : Union[Tuple[int, int], int]
+            Kernel size for max pooling.
+        stride : int, optional
+            Stride, by default 2.
+        max_pool_size : int, optional
+            Maximum pool size, by default 2.
+        ceil_mode : bool, optional
+            Ceil mode, by default False. Set to True to match output size of conv2d.
+        """
         super().__init__()
         self.dim = dim
         self.kernel_size = kernel_size
@@ -376,7 +462,18 @@ class MaxBlurPool(nn.Module):
         self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass of the function."""
+        """Forward pass of the function.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor.
+        """
        self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
        if self.dim == 2:
            return _max_blur_pool_by_kernel2d(
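Taken together, MaxBlurPool implements anti-aliased downsampling in the style of Zhang's "Making Convolutional Networks Shift-Invariant Again": a max pool followed by a blur with the normalized Pascal kernel. A hedged usage sketch (import path taken from this diff; the output shape assumes the default stride of 2):

    import torch
    from careamics.models.layers import MaxBlurPool

    pool = MaxBlurPool(dim=2, kernel_size=3)  # stride=2, max_pool_size=2 by default
    x = torch.randn(1, 16, 64, 64)            # N, C, H, W
    y = pool(x)
    print(y.shape)                            # expected: torch.Size([1, 16, 32, 32])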
careamics/models/model_factory.py CHANGED
@@ -3,6 +3,7 @@ Model factory.
 
 Model creation factory functions.
 """
+
 from typing import Union
 
 import torch
@@ -26,7 +27,7 @@ def model_factory(
     Parameters
     ----------
     model_configuration : Union[UNetModel, VAEModel]
-        Model configuration
+        Model configuration.
 
     Returns
     -------
careamics/models/unet.py CHANGED
@@ -3,7 +3,8 @@ UNet model.
 
 A UNet encoder, decoder and complete model.
 """
-from typing import Any, List, Union
+
+from typing import Any, List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -33,6 +34,11 @@ class UnetEncoder(nn.Module):
         Dropout probability, by default 0.0.
     pool_kernel : int, optional
         Kernel size for the max pooling layers, by default 2.
+    n2v2 : bool, optional
+        Whether to use N2V2 architecture, by default False.
+    groups : int, optional
+        Number of blocked connections from input channels to output
+        channels, by default 1.
     """
 
     def __init__(
@@ -45,6 +51,7 @@ class UnetEncoder(nn.Module):
         dropout: float = 0.0,
         pool_kernel: int = 2,
         n2v2: bool = False,
+        groups: int = 1,
     ) -> None:
         """
         Constructor.
@@ -65,6 +72,11 @@ class UnetEncoder(nn.Module):
             Dropout probability, by default 0.0.
         pool_kernel : int, optional
             Kernel size for the max pooling layers, by default 2.
+        n2v2 : bool, optional
+            Whether to use N2V2 architecture, by default False.
+        groups : int, optional
+            Number of blocked connections from input channels to output
+            channels, by default 1.
         """
         super().__init__()
 
@@ -77,7 +89,7 @@ class UnetEncoder(nn.Module):
         encoder_blocks = []
 
         for n in range(depth):
-            out_channels = num_channels_init * (2**n)
+            out_channels = num_channels_init * (2**n) * groups
             in_channels = in_channels if n == 0 else out_channels // 2
             encoder_blocks.append(
                 Conv_Block(
@@ -86,6 +98,7 @@ class UnetEncoder(nn.Module):
                     out_channels=out_channels,
                     dropout_perc=dropout,
                     use_batch_norm=use_batch_norm,
+                    groups=groups,
                 )
             )
             encoder_blocks.append(self.pooling)
@@ -131,6 +144,11 @@ class UnetDecoder(nn.Module):
         Whether to use batch normalization, by default True.
     dropout : float, optional
         Dropout probability, by default 0.0.
+    n2v2 : bool, optional
+        Whether to use N2V2 architecture, by default False.
+    groups : int, optional
+        Number of blocked connections from input channels to output
+        channels, by default 1.
     """
 
     def __init__(
@@ -141,6 +159,7 @@ class UnetDecoder(nn.Module):
         use_batch_norm: bool = True,
         dropout: float = 0.0,
         n2v2: bool = False,
+        groups: int = 1,
     ) -> None:
         """
         Constructor.
@@ -157,15 +176,21 @@ class UnetDecoder(nn.Module):
             Whether to use batch normalization, by default True.
         dropout : float, optional
             Dropout probability, by default 0.0.
+        n2v2 : bool, optional
+            Whether to use N2V2 architecture, by default False.
+        groups : int, optional
+            Number of blocked connections from input channels to output
+            channels, by default 1.
         """
         super().__init__()
 
         upsampling = nn.Upsample(
             scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
         )
-        in_channels = out_channels = num_channels_init * 2 ** (depth - 1)
+        in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
 
         self.n2v2 = n2v2
+        self.groups = groups
 
         self.bottleneck = Conv_Block(
             conv_dim,
@@ -174,34 +199,32 @@ class UnetDecoder(nn.Module):
             intermediate_channel_multiplier=2,
             use_batch_norm=use_batch_norm,
             dropout_perc=dropout,
+            groups=self.groups,
         )
 
-        decoder_blocks = []
+        decoder_blocks: List[nn.Module] = []
         for n in range(depth):
             decoder_blocks.append(upsampling)
-            in_channels = (
-                num_channels_init ** (depth - n)
-                if (self.n2v2 and n == depth - 1)
-                else num_channels_init * 2 ** (depth - n)
-            )
+            in_channels = (num_channels_init * 2 ** (depth - n)) * groups
             out_channels = in_channels // 2
             decoder_blocks.append(
                 Conv_Block(
                     conv_dim,
-                    in_channels=in_channels + in_channels // 2
-                    if n > 0
-                    else in_channels,
+                    in_channels=(
+                        in_channels + in_channels // 2 if n > 0 else in_channels
+                    ),
                     out_channels=out_channels,
                     intermediate_channel_multiplier=2,
                     dropout_perc=dropout,
                     activation="ReLU",
                     use_batch_norm=use_batch_norm,
+                    groups=groups,
                 )
             )
 
         self.decoder_blocks = nn.ModuleList(decoder_blocks)
 
-    def forward(self, *features: List[torch.Tensor]) -> torch.Tensor:
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
 
@@ -217,20 +240,73 @@ class UnetDecoder(nn.Module):
             Output of the decoder.
         """
         x: torch.Tensor = features[0]
-        skip_connections: torch.Tensor = features[1:][::-1]
+        skip_connections: Tuple[torch.Tensor, ...] = features[-1:0:-1]
 
         x = self.bottleneck(x)
 
         for i, module in enumerate(self.decoder_blocks):
             x = module(x)
             if isinstance(module, nn.Upsample):
+                # divide index by 2 because of upsampling layers
+                skip_connection: torch.Tensor = skip_connections[i // 2]
                 if self.n2v2:
                     if x.shape != skip_connections[-1].shape:
-                        x = torch.cat([x, skip_connections[i // 2]], axis=1)
+                        x = self._interleave(x, skip_connection, self.groups)
                 else:
-                    x = torch.cat([x, skip_connections[i // 2]], axis=1)
+                    x = self._interleave(x, skip_connection, self.groups)
         return x
 
+    @staticmethod
+    def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
+        """Interleave two tensors.
+
+        Splits the tensors `A` and `B` into equally sized groups along the channel
+        axis (axis=1); then concatenates the groups in alternating order along the
+        channel axis, starting with the first group from tensor A.
+
+        Parameters
+        ----------
+        A : torch.Tensor
+            First tensor.
+        B : torch.Tensor
+            Second tensor.
+        groups : int
+            The number of groups.
+
+        Returns
+        -------
+        torch.Tensor
+            Interleaved tensor.
+
+        Raises
+        ------
+        ValueError:
+            If either of `A` or `B`'s channel axis is not divisible by `groups`.
+        """
+        if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
+            raise ValueError(f"Number of channels not divisible by {groups} groups.")
+
+        m = A.shape[1] // groups
+        n = B.shape[1] // groups
+
+        A_groups: List[torch.Tensor] = [
+            A[:, i * m : (i + 1) * m] for i in range(groups)
+        ]
+        B_groups: List[torch.Tensor] = [
+            B[:, i * n : (i + 1) * n] for i in range(groups)
+        ]
+
+        interleaved = torch.cat(
+            [
+                tensor_list[i]
+                for i in range(groups)
+                for tensor_list in [A_groups, B_groups]
+            ],
+            dim=1,
+        )
+
+        return interleaved
+
 
 class UNet(nn.Module):
     """
@@ -257,8 +333,14 @@ class UNet(nn.Module):
         Dropout probability, by default 0.0.
     pool_kernel : int, optional
         Kernel size of the pooling layers, by default 2.
-    last_activation : Optional[Callable], optional
+    final_activation : Optional[Callable], optional
         Activation function to use for the last layer, by default None.
+    n2v2 : bool, optional
+        Whether to use N2V2 architecture, by default False.
+    independent_channels : bool
+        Whether to train the channels independently, by default True.
+    **kwargs : Any
+        Additional keyword arguments, unused.
     """
 
     def __init__(
@@ -273,6 +355,7 @@ class UNet(nn.Module):
         pool_kernel: int = 2,
         final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
         n2v2: bool = False,
+        independent_channels: bool = True,
         **kwargs: Any,
     ) -> None:
         """
@@ -296,11 +379,20 @@ class UNet(nn.Module):
             Dropout probability, by default 0.0.
         pool_kernel : int, optional
             Kernel size of the pooling layers, by default 2.
-        last_activation : Optional[Callable], optional
+        final_activation : Optional[Callable], optional
             Activation function to use for the last layer, by default None.
+        n2v2 : bool, optional
+            Whether to use N2V2 architecture, by default False.
+        independent_channels : bool
+            Whether to train parallel independent networks for each channel, by
+            default True.
+        **kwargs : Any
+            Additional keyword arguments, unused.
         """
         super().__init__()
 
+        groups = in_channels if independent_channels else 1
+
         self.encoder = UnetEncoder(
             conv_dims,
             in_channels=in_channels,
@@ -310,6 +402,7 @@ class UNet(nn.Module):
             dropout=dropout,
             pool_kernel=pool_kernel,
             n2v2=n2v2,
+            groups=groups,
         )
 
         self.decoder = UnetDecoder(
@@ -319,11 +412,13 @@ class UNet(nn.Module):
             use_batch_norm=use_batch_norm,
             dropout=dropout,
             n2v2=n2v2,
+            groups=groups,
         )
         self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
-            in_channels=num_channels_init,
+            in_channels=num_channels_init * groups,
             out_channels=num_classes,
             kernel_size=1,
+            groups=groups,
         )
         self.final_activation = get_activation(final_activation)
 
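The independent_channels flag is implemented entirely through grouped convolutions: groups = in_channels multiplies every layer's channel count by the number of groups and blocks all cross-channel connections, so each input channel is processed by its own parallel sub-network. A hedged sketch of the resulting behavior (parameter names taken from this diff; num_classes is set equal to in_channels so the grouped 1x1 output convolution is valid):

    import torch
    from careamics.models.unet import UNet  # assumed import path

    model = UNet(conv_dims=2, num_classes=2, in_channels=2, independent_channels=True)
    model.eval()

    x = torch.randn(1, 2, 64, 64)
    x2 = x.clone()
    x2[:, 1] = 0  # perturb only channel 1

    with torch.no_grad():
        y, y2 = model(x), model(x2)

    # channel 0's output should be untouched by changes to channel 1
    print(torch.allclose(y[:, 0], y2[:, 0]))  # expected: True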
careamics/prediction/stitch_prediction.py CHANGED
@@ -1,8 +1,5 @@
-"""
-Prediction convenience functions.
+"""Prediction utility functions."""
 
-These functions are used during prediction.
-"""
 from typing import List
 
 import numpy as np
@@ -20,7 +17,7 @@ def stitch_prediction(
     ----------
     tiles : List[torch.Tensor]
         Cropped tiles and their respective stitching coordinates.
-    stitching_coords : List
+    stitching_data : List
         List of information and coordinates obtained from
         `dataset.tiled_patching.extract_tiles`.
 
careamics/transforms/__init__.py CHANGED
@@ -3,39 +3,18 @@
 __all__ = [
     "get_all_transforms",
     "N2VManipulate",
-    "NDFlip",
+    "XYFlip",
     "XYRandomRotate90",
     "ImageRestorationTTA",
     "Denormalize",
     "Normalize",
+    "Compose",
 ]
 
 
+from .compose import Compose, get_all_transforms
 from .n2v_manipulate import N2VManipulate
-from .nd_flip import NDFlip
 from .normalize import Denormalize, Normalize
 from .tta import ImageRestorationTTA
+from .xy_flip import XYFlip
 from .xy_random_rotate90 import XYRandomRotate90
-
-ALL_TRANSFORMS = {
-    "Normalize": Normalize,
-    "N2VManipulate": N2VManipulate,
-    "NDFlip": NDFlip,
-    "XYRandomRotate90": XYRandomRotate90,
-}
-
-
-def get_all_transforms() -> dict:
-    """Return all the transforms accepted by CAREamics.
-
-    Note that while CAREamics accepts any `Compose` transforms from Albumentations (see
-    https://albumentations.ai/), only a few transformations are explicitely supported
-    (see `SupportedTransform`).
-
-    Returns
-    -------
-    dict
-        A dictionary with all the transforms accepted by CAREamics, where the keys are
-        the transform names and the values are the transform classes.
-    """
-    return ALL_TRANSFORMS
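With this change the module-level ALL_TRANSFORMS dictionary and its Albumentations-oriented helper are replaced by careamics' own Compose module: Compose and get_all_transforms now live in careamics.transforms.compose and are re-exported from the package root, alongside XYFlip (which replaces NDFlip). A hedged sketch of the new import surface, assuming the registry still maps transform names to classes as the removed helper did:

    # new import surface: Compose and the registry come from .compose,
    # XYFlip replaces the removed NDFlip
    from careamics.transforms import Compose, XYFlip, get_all_transforms

    print(sorted(get_all_transforms()))
    # expected to include: 'N2VManipulate', 'Normalize', 'XYFlip', 'XYRandomRotate90'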