careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's registry page for more details.

Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,5 @@
1
+ """Tiled patching utilities."""
2
+
1
3
  import itertools
2
4
  from typing import Generator, List, Tuple, Union
3
5
 
@@ -8,7 +10,7 @@ from careamics.config.tile_information import TileInformation
8
10
 
9
11
  def _compute_crop_and_stitch_coords_1d(
10
12
  axis_size: int, tile_size: int, overlap: int
11
- ) -> Tuple[List[Tuple[int, ...]], ...]:
13
+ ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
12
14
  """
13
15
  Compute the coordinates of each tile along an axis, given the overlap.
14
16
 
@@ -43,9 +45,11 @@ def _compute_crop_and_stitch_coords_1d(
43
45
  stitch_coords.append(
44
46
  (
45
47
  i + overlap // 2 if i > 0 else 0,
46
- i + tile_size - overlap // 2
47
- if crop_coords[-1][1] < axis_size
48
- else axis_size,
48
+ (
49
+ i + tile_size - overlap // 2
50
+ if crop_coords[-1][1] < axis_size
51
+ else axis_size
52
+ ),
49
53
  )
50
54
  )
51
55
 
@@ -53,9 +57,11 @@ def _compute_crop_and_stitch_coords_1d(
53
57
  overlap_crop_coords.append(
54
58
  (
55
59
  overlap // 2 if i > 0 else 0,
56
- tile_size - overlap // 2
57
- if crop_coords[-1][1] < axis_size
58
- else tile_size,
60
+ (
61
+ tile_size - overlap // 2
62
+ if crop_coords[-1][1] < axis_size
63
+ else tile_size
64
+ ),
59
65
  )
60
66
  )
61
67
 
@@ -1,3 +1,5 @@
1
+ """Patch validation functions."""
2
+
1
3
  from typing import List, Tuple, Union
2
4
 
3
5
  import numpy as np
@@ -1,3 +1,5 @@
1
+ """Zarr dataset."""
2
+
1
3
  # from itertools import islice
2
4
  # from typing import Callable, Dict, List, Optional, Tuple, Union
3
5
 
@@ -1,10 +1,10 @@
1
1
  """Training and validation Lightning data modules."""
2
+
2
3
  from pathlib import Path
3
4
  from typing import Any, Callable, Dict, List, Literal, Optional, Union
4
5
 
5
6
  import numpy as np
6
7
  import pytorch_lightning as L
7
- from albumentations import Compose
8
8
  from torch.utils.data import DataLoader
9
9
 
10
10
  from careamics.config import DataConfig
@@ -95,13 +95,13 @@ class CAREamicsTrainData(L.LightningDataModule):
95
95
  Batch size.
96
96
  use_in_memory : bool
97
97
  Whether to use in memory dataset if possible.
98
- train_data : Union[Path, str, np.ndarray]
98
+ train_data : Union[Path, np.ndarray]
99
99
  Training data.
100
- val_data : Optional[Union[Path, str, np.ndarray]]
100
+ val_data : Optional[Union[Path, np.ndarray]]
101
101
  Validation data.
102
- train_data_target : Optional[Union[Path, str, np.ndarray]]
102
+ train_data_target : Optional[Union[Path, np.ndarray]]
103
103
  Training target data.
104
- val_data_target : Optional[Union[Path, str, np.ndarray]]
104
+ val_data_target : Optional[Union[Path, np.ndarray]]
105
105
  Validation target data.
106
106
  val_percentage : float
107
107
  Percentage of the training data to use for validation, if no validation data is
@@ -217,17 +217,33 @@ class CAREamicsTrainData(L.LightningDataModule):
217
217
  )
218
218
 
219
219
  # configuration
220
- self.data_config = data_config
221
- self.data_type = data_config.data_type
222
- self.batch_size = data_config.batch_size
223
- self.use_in_memory = use_in_memory
220
+ self.data_config: DataConfig = data_config
221
+ self.data_type: str = data_config.data_type
222
+ self.batch_size: int = data_config.batch_size
223
+ self.use_in_memory: bool = use_in_memory
224
+
225
+ # data: make data Path or np.ndarray, use type annotations for mypy
226
+ self.train_data: Union[Path, np.ndarray] = (
227
+ Path(train_data) if isinstance(train_data, str) else train_data
228
+ )
224
229
 
225
- # data
226
- self.train_data = train_data
227
- self.val_data = val_data
230
+ self.val_data: Union[Path, np.ndarray] = (
231
+ Path(val_data) if isinstance(val_data, str) else val_data
232
+ )
228
233
 
229
- self.train_data_target = train_data_target
230
- self.val_data_target = val_data_target
234
+ self.train_data_target: Union[Path, np.ndarray] = (
235
+ Path(train_data_target)
236
+ if isinstance(train_data_target, str)
237
+ else train_data_target
238
+ )
239
+
240
+ self.val_data_target: Union[Path, np.ndarray] = (
241
+ Path(val_data_target)
242
+ if isinstance(val_data_target, str)
243
+ else val_data_target
244
+ )
245
+
246
+ # validation split
231
247
  self.val_percentage = val_percentage
232
248
  self.val_minimum_split = val_minimum_split
233
249
 
@@ -241,10 +257,10 @@ class CAREamicsTrainData(L.LightningDataModule):
241
257
  elif data_config.data_type != SupportedData.ARRAY:
242
258
  self.read_source_func = get_read_func(data_config.data_type)
243
259
 
244
- self.extension_filter = extension_filter
260
+ self.extension_filter: str = extension_filter
245
261
 
246
262
  # Pytorch dataloader parameters
247
- self.dataloader_params = (
263
+ self.dataloader_params: Dict[str, Any] = (
248
264
  data_config.dataloader_params if data_config.dataloader_params else {}
249
265
  )
250
266
 
@@ -309,20 +325,30 @@ class CAREamicsTrainData(L.LightningDataModule):
309
325
  """
310
326
  # if numpy array
311
327
  if self.data_type == SupportedData.ARRAY:
328
+ # mypy checks
329
+ assert isinstance(self.train_data, np.ndarray)
330
+ if self.train_data_target is not None:
331
+ assert isinstance(self.train_data_target, np.ndarray)
332
+
312
333
  # train dataset
313
334
  self.train_dataset: DatasetType = InMemoryDataset(
314
335
  data_config=self.data_config,
315
336
  inputs=self.train_data,
316
- data_target=self.train_data_target,
337
+ input_target=self.train_data_target,
317
338
  )
318
339
 
319
340
  # validation dataset
320
341
  if self.val_data is not None:
342
+ # mypy checks
343
+ assert isinstance(self.val_data, np.ndarray)
344
+ if self.val_data_target is not None:
345
+ assert isinstance(self.val_data_target, np.ndarray)
346
+
321
347
  # create its own dataset
322
348
  self.val_dataset: DatasetType = InMemoryDataset(
323
349
  data_config=self.data_config,
324
350
  inputs=self.val_data,
325
- data_target=self.val_data_target,
351
+ input_target=self.val_data_target,
326
352
  )
327
353
  else:
328
354
  # extract validation from the training patches
@@ -341,9 +367,9 @@ class CAREamicsTrainData(L.LightningDataModule):
341
367
  self.train_dataset = InMemoryDataset(
342
368
  data_config=self.data_config,
343
369
  inputs=self.train_files,
344
- data_target=self.train_target_files
345
- if self.train_data_target
346
- else None,
370
+ input_target=(
371
+ self.train_target_files if self.train_data_target else None
372
+ ),
347
373
  read_source_func=self.read_source_func,
348
374
  )
349
375
 
@@ -352,9 +378,9 @@ class CAREamicsTrainData(L.LightningDataModule):
352
378
  self.val_dataset = InMemoryDataset(
353
379
  data_config=self.data_config,
354
380
  inputs=self.val_files,
355
- data_target=self.val_target_files
356
- if self.val_data_target
357
- else None,
381
+ input_target=(
382
+ self.val_target_files if self.val_data_target else None
383
+ ),
358
384
  read_source_func=self.read_source_func,
359
385
  )
360
386
  else:
@@ -370,9 +396,9 @@ class CAREamicsTrainData(L.LightningDataModule):
370
396
  self.train_dataset = PathIterableDataset(
371
397
  data_config=self.data_config,
372
398
  src_files=self.train_files,
373
- target_files=self.train_target_files
374
- if self.train_data_target
375
- else None,
399
+ target_files=(
400
+ self.train_target_files if self.train_data_target else None
401
+ ),
376
402
  read_source_func=self.read_source_func,
377
403
  )
378
404
 
@@ -382,9 +408,9 @@ class CAREamicsTrainData(L.LightningDataModule):
382
408
  self.val_dataset = PathIterableDataset(
383
409
  data_config=self.data_config,
384
410
  src_files=self.val_files,
385
- target_files=self.val_target_files
386
- if self.val_data_target
387
- else None,
411
+ target_files=(
412
+ self.val_target_files if self.val_data_target else None
413
+ ),
388
414
  read_source_func=self.read_source_func,
389
415
  )
390
416
  elif len(self.train_files) <= self.val_minimum_split:
@@ -452,8 +478,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
452
478
  In particular, N2V requires a specific transformation (N2V manipulates), which is
453
479
  not compatible with supervised training. The default transformations applied to the
454
480
  training patches are defined in `careamics.config.data_model`. To use different
455
- transformations, pass a list of transforms or an albumentation `Compose` as
456
- `transforms` parameter. See examples for more details.
481
+ transformations, pass a list of transforms. See examples for more details.
457
482
 
458
483
  By default, CAREamics only supports types defined in
459
484
  `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -488,7 +513,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
488
513
  Batch size.
489
514
  val_data : Optional[Union[str, Path]], optional
490
515
  Validation data, by default None.
491
- transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
516
+ transforms : List[TRANSFORMS_UNION], optional
492
517
  List of transforms to apply to training patches. If None, default transforms
493
518
  are applied.
494
519
  train_target_data : Optional[Union[str, Path]], optional
@@ -584,7 +609,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
584
609
  axes: str,
585
610
  batch_size: int,
586
611
  val_data: Optional[Union[str, Path]] = None,
587
- transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
612
+ transforms: Optional[List[TRANSFORMS_UNION]] = None,
588
613
  train_target_data: Optional[Union[str, Path]] = None,
589
614
  val_target_data: Optional[Union[str, Path]] = None,
590
615
  read_source_func: Optional[Callable] = None,
@@ -617,8 +642,8 @@ class TrainingDataWrapper(CAREamicsTrainData):
617
642
  In particular, N2V requires a specific transformation (N2V manipulates), which
618
643
  is not compatible with supervised training. The default transformations applied
619
644
  to the training patches are defined in `careamics.config.data_model`. To use
620
- different transformations, pass a list of transforms or an albumentation
621
- `Compose` as `transforms` parameter. See examples for more details.
645
+ different transformations, pass a list of transforms. See examples for more
646
+ details.
622
647
 
623
648
  By default, CAREamics only supports types defined in
624
649
  `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -655,7 +680,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
655
680
  Batch size.
656
681
  val_data : Optional[Union[str, Path]], optional
657
682
  Validation data, by default None.
658
- transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
683
+ transforms : Optional[List[TRANSFORMS_UNION]], optional
659
684
  List of transforms to apply to training patches. If None, default transforms
660
685
  are applied.
661
686
  train_target_data : Optional[Union[str, Path]], optional
@@ -709,10 +734,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
709
734
  self.data_config = DataConfig(**data_dict)
710
735
 
711
736
  # N2V specific checks, N2V, structN2V, and transforms
712
- if (
713
- self.data_config.has_transform_list()
714
- and self.data_config.has_n2v_manipulate()
715
- ):
737
+ if self.data_config.has_n2v_manipulate():
716
738
  # there is not target, n2v2 and structN2V can be changed
717
739
  if train_target_data is None:
718
740
  self.data_config.set_N2V2(use_n2v2)
@@ -1,3 +1,5 @@
1
+ """CAREamics Lightning module."""
2
+
1
3
  from typing import Any, Optional, Union
2
4
 
3
5
  import pytorch_lightning as L
@@ -24,6 +26,11 @@ class CAREamicsModule(L.LightningModule):
24
26
  This class encapsulates the a PyTorch model along with the training, validation,
25
27
  and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
26
28
 
29
+ Parameters
30
+ ----------
31
+ algorithm_config : Union[AlgorithmModel, dict]
32
+ Algorithm configuration.
33
+
27
34
  Attributes
28
35
  ----------
29
36
  model : nn.Module
@@ -39,8 +46,7 @@ class CAREamicsModule(L.LightningModule):
39
46
  """
40
47
 
41
48
  def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
42
- """
43
- CAREamics Lightning module.
49
+ """Lightning module for CAREamics.
44
50
 
45
51
  This class encapsulates the a PyTorch model along with the training, validation,
46
52
  and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
@@ -162,7 +168,7 @@ class CAREamicsModule(L.LightningModule):
162
168
  mean=self._trainer.datamodule.predict_dataset.mean,
163
169
  std=self._trainer.datamodule.predict_dataset.std,
164
170
  )
165
- denormalized_output = denorm(image=output)["image"]
171
+ denormalized_output, _ = denorm(patch=output)
166
172
 
167
173
  if len(aux) > 0:
168
174
  return denormalized_output, aux
@@ -1,10 +1,10 @@
1
1
  """Prediction Lightning data modules."""
2
+
2
3
  from pathlib import Path
3
- from typing import Any, Callable, List, Literal, Optional, Tuple, Union
4
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
4
5
 
5
6
  import numpy as np
6
7
  import pytorch_lightning as L
7
- from albumentations import Compose
8
8
  from torch.utils.data import DataLoader
9
9
  from torch.utils.data.dataloader import default_collate
10
10
 
@@ -39,7 +39,7 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
39
39
 
40
40
  Parameters
41
41
  ----------
42
- batch : Tuple[Tuple[np.ndarray, TileInformation], ...]
42
+ batch : List[Tuple[np.ndarray, TileInformation], ...]
43
43
  Batch of tiles.
44
44
 
45
45
  Returns
@@ -257,14 +257,13 @@ class PredictDataWrapper(CAREamicsPredictData):
257
257
 
258
258
  The default transformations applied to the images are defined in
259
259
  `careamics.config.inference_model`. To use different transformations, pass a list
260
- of transforms or an albumentation `Compose` as `transforms` parameter. See examples
260
+ of transforms. See examples
261
261
  for more details.
262
262
 
263
263
  The `mean` and `std` parameters are only used if Normalization is defined either
264
- in the default transformations or in the `transforms` parameter, but not with
265
- a `Compose` object. If you pass a `Normalization` transform in a list as
266
- `transforms`, then the mean and std parameters will be overwritten by those passed
267
- to this method.
264
+ in the default transformations or in the `transforms` parameter. If you pass a
265
+ `Normalization` transform in a list as `transforms`, then the mean and std
266
+ parameters will be overwritten by those passed to this method.
268
267
 
269
268
  By default, CAREamics only supports types defined in
270
269
  `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -276,6 +275,12 @@ class PredictDataWrapper(CAREamicsPredictData):
276
275
  dataloaders, except for `batch_size`, which is set by the `batch_size`
277
276
  parameter.
278
277
 
278
+ Note that if you are using a UNet model and tiling, the tile size must be
279
+ divisible in every dimension by 2**d, where d is the depth of the model. This
280
+ avoids artefacts arising from the broken shift invariance induced by the
281
+ pooling layers of the UNet. If your image has less dimensions, as it may
282
+ happen in the Z dimension, consider padding your image.
283
+
279
284
  Parameters
280
285
  ----------
281
286
  pred_data : Union[str, Path, np.ndarray]
@@ -298,9 +303,6 @@ class PredictDataWrapper(CAREamicsPredictData):
298
303
  Batch size.
299
304
  tta_transforms : bool, optional
300
305
  Use test time augmentation, by default True.
301
- transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
302
- List of transforms to apply to prediction patches. If None, default
303
- transforms are applied.
304
306
  read_source_func : Optional[Callable], optional
305
307
  Function to read the source data, used if `data_type` is `custom`, by
306
308
  default None.
@@ -321,7 +323,6 @@ class PredictDataWrapper(CAREamicsPredictData):
321
323
  axes: str = "YX",
322
324
  batch_size: int = 1,
323
325
  tta_transforms: bool = True,
324
- transforms: Optional[Union[List, Compose]] = None,
325
326
  read_source_func: Optional[Callable] = None,
326
327
  extension_filter: str = "",
327
328
  dataloader_params: Optional[dict] = None,
@@ -351,9 +352,6 @@ class PredictDataWrapper(CAREamicsPredictData):
351
352
  Batch size.
352
353
  tta_transforms : bool, optional
353
354
  Use test time augmentation, by default True.
354
- transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
355
- List of transforms to apply to prediction patches. If None, default
356
- transforms are applied.
357
355
  read_source_func : Optional[Callable], optional
358
356
  Function to read the source data, used if `data_type` is `custom`, by
359
357
  default None.
@@ -364,7 +362,7 @@ class PredictDataWrapper(CAREamicsPredictData):
364
362
  """
365
363
  if dataloader_params is None:
366
364
  dataloader_params = {}
367
- prediction_dict = {
365
+ prediction_dict: Dict[str, Any] = {
368
366
  "data_type": data_type,
369
367
  "tile_size": tile_size,
370
368
  "tile_overlap": tile_overlap,
@@ -373,12 +371,9 @@ class PredictDataWrapper(CAREamicsPredictData):
373
371
  "std": std,
374
372
  "tta": tta_transforms,
375
373
  "batch_size": batch_size,
374
+ "transforms": [],
376
375
  }
377
376
 
378
- # if transforms are passed (otherwise it will use the default ones)
379
- if transforms is not None:
380
- prediction_dict["transforms"] = transforms
381
-
382
377
  # validate configuration
383
378
  self.prediction_config = InferenceConfig(**prediction_dict)
384
379
 
@@ -1,3 +1,5 @@
1
+ """Lithning prediction loop allowing tiling."""
2
+
1
3
  from typing import Optional
2
4
 
3
5
  import pytorch_lightning as L
@@ -18,14 +20,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
18
20
  """
19
21
 
20
22
  def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
21
- """
22
- Calls `on_predict_epoch_end` hook.
23
+ """Call `on_predict_epoch_end` hook.
23
24
 
24
25
  Adapted from the parent method.
25
26
 
26
27
  Returns
27
28
  -------
28
- the results for all dataloaders
29
+ Optional[_PREDICT_OUTPUT]
30
+ Prediction output.
29
31
  """
30
32
  trainer = self.trainer
31
33
  call._call_callback_hooks(trainer, "on_predict_epoch_end")
@@ -45,15 +47,14 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
45
47
 
46
48
  @_no_grad_context
47
49
  def run(self) -> Optional[_PREDICT_OUTPUT]:
48
- """
49
- Runs the prediction loop.
50
+ """Run the prediction loop.
50
51
 
51
52
  Adapted from the parent method in order to stitch the predictions.
52
53
 
53
54
  Returns
54
55
  -------
55
56
  Optional[_PREDICT_OUTPUT]
56
- Prediction output
57
+ Prediction output.
57
58
  """
58
59
  self.setup_data()
59
60
  if self.skip:
@@ -86,6 +87,7 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop):
86
87
 
87
88
  ########################################################
88
89
  ################ CAREamics specific code ###############
90
+ # TODO: next line is not compatible with muSplit
89
91
  is_tiled = len(self.predictions[batch_idx]) == 2
90
92
  if is_tiled:
91
93
  # extract the last tile flag and the coordinates (crop and stitch)
@@ -1,7 +1,5 @@
1
1
  """Losses module."""
2
2
 
3
+ __all__ = ["loss_factory"]
3
4
 
4
5
  from .loss_factory import loss_factory
5
-
6
- # from .noise_model_factory import noise_model_factory as noise_model_factory
7
- # from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel
@@ -3,6 +3,7 @@ Loss factory module.
3
3
 
4
4
  This module contains a factory function for creating loss functions.
5
5
  """
6
+
6
7
  from typing import Callable, Union
7
8
 
8
9
  from ..config.support import SupportedLoss
@@ -16,7 +17,7 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
16
17
 
17
18
  Parameters
18
19
  ----------
19
- loss: SupportedLoss
20
+ loss : Union[SupportedLoss, str]
20
21
  Requested loss.
21
22
 
22
23
  Returns
@@ -5,23 +5,27 @@ This submodule contains the various losses used in CAREamics.
5
5
  """
6
6
 
7
7
  import torch
8
-
9
- # TODO if we are only using the DiceLoss, can we just implement it?
10
- # from segmentation_models_pytorch.losses import DiceLoss
11
8
  from torch.nn import L1Loss, MSELoss
12
9
 
13
10
 
14
- def mse_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
11
+ def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
15
12
  """
16
13
  Mean squared error loss.
17
14
 
15
+ Parameters
16
+ ----------
17
+ source : torch.Tensor
18
+ Source patches.
19
+ target : torch.Tensor
20
+ Target patches.
21
+
18
22
  Returns
19
23
  -------
20
24
  torch.Tensor
21
25
  Loss value.
22
26
  """
23
27
  loss = MSELoss()
24
- return loss(samples, labels)
28
+ return loss(source, target)
25
29
 
26
30
 
27
31
  def n2v_loss(
@@ -34,9 +38,9 @@ def n2v_loss(
34
38
 
35
39
  Parameters
36
40
  ----------
37
- samples : torch.Tensor
41
+ manipulated_patches : torch.Tensor
38
42
  Patches with manipulated pixels.
39
- labels : torch.Tensor
43
+ original_patches : torch.Tensor
40
44
  Noisy patches.
41
45
  masks : torch.Tensor
42
46
  Array containing masked pixel locations.
@@ -1,6 +1,5 @@
1
1
  """Model I/O utilities."""
2
2
 
3
-
4
3
  __all__ = ["load_pretrained", "export_to_bmz"]
5
4
 
6
5
 
@@ -1,4 +1,5 @@
1
1
  """Functions used to create a README.md file for BMZ export."""
2
+
2
3
  from pathlib import Path
3
4
  from typing import Optional
4
5
 
@@ -117,4 +118,4 @@ def readme_factory(
117
118
 
118
119
  readme.write_text("".join(description))
119
120
 
120
- return readme
121
+ return readme.absolute()
@@ -1,4 +1,5 @@
1
1
  """Bioimage.io utils."""
2
+
2
3
  from pathlib import Path
3
4
  from typing import Union
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Module use to build BMZ model description."""
2
+
2
3
  from pathlib import Path
3
4
  from typing import List, Optional, Tuple, Union
4
5
 
@@ -1,4 +1,5 @@
1
1
  """Function to export to the BioImage Model Zoo format."""
2
+
2
3
  import tempfile
3
4
  from pathlib import Path
4
5
  from typing import List, Optional, Tuple, Union
@@ -103,9 +104,9 @@ def export_to_bmz(
103
104
  authors : List[dict]
104
105
  Authors of the model.
105
106
  input_array : np.ndarray
106
- Input array.
107
+ Input array, should not have been normalized.
107
108
  output_array : np.ndarray
108
- Output array.
109
+ Output array, should have been denormalized.
109
110
  channel_names : Optional[List[str]], optional
110
111
  Channel names, by default None.
111
112
  data_description : Optional[str], optional
@@ -177,7 +178,7 @@ def export_to_bmz(
177
178
  )
178
179
 
179
180
  # test model description
180
- summary: ValidationSummary = test_model(model_description)
181
+ summary: ValidationSummary = test_model(model_description, decimal=2)
181
182
  if summary.status == "failed":
182
183
  raise ValueError(f"Model description test failed: {summary}")
183
184
 
@@ -1,3 +1,5 @@
1
+ """Activations for CAREamics models."""
2
+
1
3
  from typing import Callable, Union
2
4
 
3
5
  import torch.nn as nn