careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic (the original registry page linked to further details on this).

Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,5 @@
1
+ """Callback saving CAREamics configuration as hyperparameters in the model."""
2
+
1
3
  from pytorch_lightning import LightningModule, Trainer
2
4
  from pytorch_lightning.callbacks import Callback
3
5
 
@@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
11
13
  This allows saving the configuration as dictionnary in the checkpoints, and
12
14
  loading it subsequently in a CAREamist instance.
13
15
 
16
+ Parameters
17
+ ----------
18
+ config : Configuration
19
+ CAREamics configuration to be saved as hyperparameter in the model.
20
+
14
21
  Attributes
15
22
  ----------
16
23
  config : Configuration
17
24
  CAREamics configuration to be saved as hyperparameter in the model.
18
25
  """
19
26
 
20
- def __init__(self, config: Configuration):
27
+ def __init__(self, config: Configuration) -> None:
21
28
  """
22
29
  Constructor.
23
30
 
@@ -28,14 +35,14 @@ class HyperParametersCallback(Callback):
28
35
  """
29
36
  self.config = config
30
37
 
31
- def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
38
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
32
39
  """
33
40
  Update the hyperparameters of the model with the configuration on train start.
34
41
 
35
42
  Parameters
36
43
  ----------
37
44
  trainer : Trainer
38
- PyTorch Lightning trainer.
45
+ PyTorch Lightning trainer, unused.
39
46
  pl_module : LightningModule
40
47
  PyTorch Lightning module.
41
48
  """
@@ -1,3 +1,5 @@
1
+ """Progressbar callback."""
2
+
1
3
  import sys
2
4
  from typing import Dict, Union
3
5
 
@@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
10
12
  """Progress bar for training and validation steps."""
11
13
 
12
14
  def init_train_tqdm(self) -> tqdm:
13
- """Override this to customize the tqdm bar for training."""
15
+ """Override this to customize the tqdm bar for training.
16
+
17
+ Returns
18
+ -------
19
+ tqdm
20
+ A tqdm bar.
21
+ """
14
22
  bar = tqdm(
15
23
  desc="Training",
16
24
  position=(2 * self.process_position),
@@ -23,7 +31,13 @@ class ProgressBarCallback(TQDMProgressBar):
23
31
  return bar
24
32
 
25
33
  def init_validation_tqdm(self) -> tqdm:
26
- """Override this to customize the tqdm bar for validation."""
34
+ """Override this to customize the tqdm bar for validation.
35
+
36
+ Returns
37
+ -------
38
+ tqdm
39
+ A tqdm bar.
40
+ """
27
41
  # The main progress bar doesn't exist in `trainer.validate()`
28
42
  has_main_bar = self.train_progress_bar is not None
29
43
  bar = tqdm(
@@ -37,7 +51,13 @@ class ProgressBarCallback(TQDMProgressBar):
37
51
  return bar
38
52
 
39
53
  def init_test_tqdm(self) -> tqdm:
40
- """Override this to customize the tqdm bar for testing."""
54
+ """Override this to customize the tqdm bar for testing.
55
+
56
+ Returns
57
+ -------
58
+ tqdm
59
+ A tqdm bar.
60
+ """
41
61
  bar = tqdm(
42
62
  desc="Testing",
43
63
  position=(2 * self.process_position),
@@ -52,6 +72,19 @@ class ProgressBarCallback(TQDMProgressBar):
52
72
  def get_metrics(
53
73
  self, trainer: Trainer, pl_module: LightningModule
54
74
  ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
55
- """Override this to customize the metrics displayed in the progress bar."""
75
+ """Override this to customize the metrics displayed in the progress bar.
76
+
77
+ Parameters
78
+ ----------
79
+ trainer : Trainer
80
+ The trainer object.
81
+ pl_module : LightningModule
82
+ The LightningModule object, unused.
83
+
84
+ Returns
85
+ -------
86
+ dict
87
+ A dictionary with the metrics to display in the progress bar.
88
+ """
56
89
  pbar_metrics = trainer.progress_bar_metrics
57
90
  return {**pbar_metrics}
careamics/careamist.py CHANGED
@@ -18,13 +18,14 @@ from careamics.config import (
18
18
  create_inference_configuration,
19
19
  load_configuration,
20
20
  )
21
- from careamics.config.inference_model import TRANSFORMS_UNION
22
21
  from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
22
+ from careamics.dataset.dataset_utils import reshape_array
23
23
  from careamics.lightning_datamodule import CAREamicsTrainData
24
24
  from careamics.lightning_module import CAREamicsModule
25
25
  from careamics.lightning_prediction_datamodule import CAREamicsPredictData
26
26
  from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
27
27
  from careamics.model_io import export_to_bmz, load_pretrained
28
+ from careamics.transforms import Denormalize
28
29
  from careamics.utils import check_path_exists, get_logger
29
30
 
30
31
  from .callbacks import HyperParametersCallback
@@ -73,8 +74,7 @@ class CAREamist:
73
74
  source: Union[Path, str],
74
75
  work_dir: Optional[str] = None,
75
76
  experiment_name: str = "CAREamics",
76
- ) -> None:
77
- ...
77
+ ) -> None: ...
78
78
 
79
79
  @overload
80
80
  def __init__( # numpydoc ignore=GL08
@@ -82,8 +82,7 @@ class CAREamist:
82
82
  source: Configuration,
83
83
  work_dir: Optional[str] = None,
84
84
  experiment_name: str = "CAREamics",
85
- ) -> None:
86
- ...
85
+ ) -> None: ...
87
86
 
88
87
  def __init__(
89
88
  self,
@@ -478,8 +477,7 @@ class CAREamist:
478
477
  source: CAREamicsPredictData,
479
478
  *,
480
479
  checkpoint: Optional[Literal["best", "last"]] = None,
481
- ) -> Union[list, np.ndarray]:
482
- ...
480
+ ) -> Union[list, np.ndarray]: ...
483
481
 
484
482
  @overload
485
483
  def predict( # numpydoc ignore=GL08
@@ -491,14 +489,12 @@ class CAREamist:
491
489
  tile_overlap: Tuple[int, ...] = (48, 48),
492
490
  axes: Optional[str] = None,
493
491
  data_type: Optional[Literal["tiff", "custom"]] = None,
494
- transforms: Optional[List[TRANSFORMS_UNION]] = None,
495
492
  tta_transforms: bool = True,
496
493
  dataloader_params: Optional[Dict] = None,
497
494
  read_source_func: Optional[Callable] = None,
498
495
  extension_filter: str = "",
499
496
  checkpoint: Optional[Literal["best", "last"]] = None,
500
- ) -> Union[list, np.ndarray]:
501
- ...
497
+ ) -> Union[list, np.ndarray]: ...
502
498
 
503
499
  @overload
504
500
  def predict( # numpydoc ignore=GL08
@@ -510,12 +506,10 @@ class CAREamist:
510
506
  tile_overlap: Tuple[int, ...] = (48, 48),
511
507
  axes: Optional[str] = None,
512
508
  data_type: Optional[Literal["array"]] = None,
513
- transforms: Optional[List[TRANSFORMS_UNION]] = None,
514
509
  tta_transforms: bool = True,
515
510
  dataloader_params: Optional[Dict] = None,
516
511
  checkpoint: Optional[Literal["best", "last"]] = None,
517
- ) -> Union[list, np.ndarray]:
518
- ...
512
+ ) -> Union[list, np.ndarray]: ...
519
513
 
520
514
  def predict(
521
515
  self,
@@ -526,7 +520,6 @@ class CAREamist:
526
520
  tile_overlap: Tuple[int, ...] = (48, 48),
527
521
  axes: Optional[str] = None,
528
522
  data_type: Optional[Literal["array", "tiff", "custom"]] = None,
529
- transforms: Optional[List[TRANSFORMS_UNION]] = None,
530
523
  tta_transforms: bool = True,
531
524
  dataloader_params: Optional[Dict] = None,
532
525
  read_source_func: Optional[Callable] = None,
@@ -543,11 +536,15 @@ class CAREamist:
543
536
  configuration parameters will be used, with the `patch_size` instead of
544
537
  `tile_size`.
545
538
 
546
- The default transforms are defined in the `InferenceModel` Pydantic model.
547
-
548
539
  Test-time augmentation (TTA) can be switched off using the `tta_transforms`
549
540
  parameter.
550
541
 
542
+ Note that if you are using a UNet model and tiling, the tile size must be
543
+ divisible in every dimension by 2**d, where d is the depth of the model. This
544
+ avoids artefacts arising from the broken shift invariance induced by the
545
+ pooling layers of the UNet. If your image has less dimensions, as it may
546
+ happen in the Z dimension, consider padding your image.
547
+
551
548
  Parameters
552
549
  ----------
553
550
  source : Union[CAREamicsClay, Path, str, np.ndarray]
@@ -562,8 +559,6 @@ class CAREamist:
562
559
  Axes of the input data, by default None.
563
560
  data_type : Optional[Literal["array", "tiff", "custom"]], optional
564
561
  Type of the input data, by default None.
565
- transforms : Optional[List[TRANSFORMS_UNION]], optional
566
- List of transforms to apply to the data, by default None.
567
562
  tta_transforms : bool, optional
568
563
  Whether to apply test-time augmentation, by default True.
569
564
  dataloader_params : Optional[Dict], optional
@@ -602,12 +597,11 @@ class CAREamist:
602
597
  )
603
598
  # create predict config, reuse training config if parameters missing
604
599
  prediction_config = create_inference_configuration(
605
- training_configuration=self.cfg,
600
+ configuration=self.cfg,
606
601
  tile_size=tile_size,
607
602
  tile_overlap=tile_overlap,
608
603
  data_type=data_type,
609
604
  axes=axes,
610
- transforms=transforms,
611
605
  tta_transforms=tta_transforms,
612
606
  batch_size=batch_size,
613
607
  )
@@ -659,38 +653,41 @@ class CAREamist:
659
653
  f"np.ndarray (got {type(source)})."
660
654
  )
661
655
 
662
- def export_to_bmz(
656
+ def _create_data_for_bmz(
663
657
  self,
664
- path: Union[Path, str],
665
- name: str,
666
- authors: List[dict],
667
658
  input_array: Optional[np.ndarray] = None,
668
- general_description: str = "",
669
- channel_names: Optional[List[str]] = None,
670
- data_description: Optional[str] = None,
671
- ) -> None:
672
- """Export the model to the BioImage Model Zoo format.
659
+ ) -> np.ndarray:
660
+ """Create data for BMZ export.
673
661
 
674
- Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
662
+ If no `input_array` is provided, this method checks if there is a prediction
663
+ datamodule, or a training data module, to extract a patch. If none exists,
664
+ then a random aray is created.
665
+
666
+ If there is a non-singleton batch dimension, this method returns only the first
667
+ element.
675
668
 
676
669
  Parameters
677
670
  ----------
678
- path : Union[Path, str]
679
- Path to save the model.
680
- name : str
681
- Name of the model.
682
- authors : List[dict]
683
- List of authors of the model.
684
671
  input_array : Optional[np.ndarray], optional
685
- Input array for the model, must be of shape SC(Z)YX, by default None.
686
- general_description : str
687
- General description of the model, used in the metadata of the BMZ archive.
688
- channel_names : Optional[List[str]], optional
689
- Channel names, by default None.
690
- data_description : Optional[str], optional
691
- Description of the data, by default None.
672
+ Input array, by default None.
673
+
674
+ Returns
675
+ -------
676
+ np.ndarray
677
+ Input data for BMZ export.
678
+
679
+ Raises
680
+ ------
681
+ ValueError
682
+ If mean and std are not provided in the configuration.
692
683
  """
693
684
  if input_array is None:
685
+ if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
686
+ raise ValueError(
687
+ "Mean and std cannot be None in the configuration in order to"
688
+ "export to the BMZ format. Was the model trained?"
689
+ )
690
+
694
691
  # generate images, priority is given to the prediction data module
695
692
  if self.pred_datamodule is not None:
696
693
  # unpack a batch, ignore masks or targets
@@ -698,19 +695,23 @@ class CAREamist:
698
695
 
699
696
  # convert torch.Tensor to numpy
700
697
  input_patch = input_patch.numpy()
698
+
699
+ # denormalize
700
+ denormalize = Denormalize(
701
+ mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
702
+ )
703
+ input_patch, _ = denormalize(input_patch)
704
+
701
705
  elif self.train_datamodule is not None:
702
706
  input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
703
707
  input_patch = input_patch.numpy()
704
- else:
705
- if (
706
- self.cfg.data_config.mean is None
707
- or self.cfg.data_config.std is None
708
- ):
709
- raise ValueError(
710
- "Mean and std cannot be None in the configuration in order to"
711
- "export to the BMZ format. Was the model trained?"
712
- )
713
708
 
709
+ # denormalize
710
+ denormalize = Denormalize(
711
+ mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
712
+ )
713
+ input_patch, _ = denormalize(input_patch)
714
+ else:
714
715
  # create a random input array
715
716
  input_patch = np.random.normal(
716
717
  loc=self.cfg.data_config.mean,
@@ -720,11 +721,47 @@ class CAREamist:
720
721
  np.newaxis, np.newaxis, ...
721
722
  ] # add S & C dimensions
722
723
  else:
723
- input_patch = input_array
724
+ # potentially correct shape
725
+ input_patch = reshape_array(input_array, self.cfg.data_config.axes)
724
726
 
725
- # if there is a batch dimension
727
+ # if this a batch
726
728
  if input_patch.shape[0] > 1:
727
- input_patch = input_patch[0:1, ...] # keep singleton dim
729
+ input_patch = input_patch[[0], ...] # keep singleton dim
730
+
731
+ return input_patch
732
+
733
+ def export_to_bmz(
734
+ self,
735
+ path: Union[Path, str],
736
+ name: str,
737
+ authors: List[dict],
738
+ input_array: Optional[np.ndarray] = None,
739
+ general_description: str = "",
740
+ channel_names: Optional[List[str]] = None,
741
+ data_description: Optional[str] = None,
742
+ ) -> None:
743
+ """Export the model to the BioImage Model Zoo format.
744
+
745
+ Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
746
+
747
+ Parameters
748
+ ----------
749
+ path : Union[Path, str]
750
+ Path to save the model.
751
+ name : str
752
+ Name of the model.
753
+ authors : List[dict]
754
+ List of authors of the model.
755
+ input_array : Optional[np.ndarray], optional
756
+ Input array for the model, must be of shape SC(Z)YX, by default None.
757
+ general_description : str
758
+ General description of the model, used in the metadata of the BMZ archive.
759
+ channel_names : Optional[List[str]], optional
760
+ Channel names, by default None.
761
+ data_description : Optional[str], optional
762
+ Description of the data, by default None.
763
+ """
764
+ input_patch = self._create_data_for_bmz(input_array)
728
765
 
729
766
  # axes need to be reformated for the export because reshaping was done in the
730
767
  # datamodule
@@ -1,6 +1,5 @@
1
1
  """Configuration module."""
2
2
 
3
-
4
3
  __all__ = [
5
4
  "AlgorithmConfig",
6
5
  "DataConfig",
@@ -1,3 +1,5 @@
1
+ """Algorithm configuration."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from pprint import pformat
@@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel):
17
19
  training algorithm: which algorithm, loss function, model architecture, optimizer,
18
20
  and learning rate scheduler to use.
19
21
 
20
- Currently, we only support N2V and custom algorithms. The `n2v` algorithm is only
21
- compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm allows
22
- you to register your own architecture and select it using its name as
22
+ Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
23
+ only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
24
+ allows you to register your own architecture and select it using its name as
23
25
  `name` in the custom pydantic model.
24
26
 
25
27
  Attributes
@@ -1,3 +1,5 @@
1
+ """Base model for the various CAREamics architectures."""
2
+
1
3
  from typing import Any, Dict
2
4
 
3
5
  from pydantic import BaseModel
@@ -16,6 +18,11 @@ class ArchitectureModel(BaseModel):
16
18
  """
17
19
  Dump the model as a dictionary, ignoring the architecture keyword.
18
20
 
21
+ Parameters
22
+ ----------
23
+ **kwargs : Any
24
+ Additional keyword arguments from Pydantic BaseModel model_dump method.
25
+
19
26
  Returns
20
27
  -------
21
28
  dict[str, Any]
@@ -1,3 +1,5 @@
1
+ """Custom architecture Pydantic model."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from pprint import pformat
@@ -84,6 +86,11 @@ class CustomModel(ArchitectureModel):
84
86
  value : str
85
87
  Name of the custom model as registered using the `@register_model`
86
88
  decorator.
89
+
90
+ Returns
91
+ -------
92
+ str
93
+ The custom model name.
87
94
  """
88
95
  # delegate error to get_custom_model
89
96
  model = get_custom_model(value)
@@ -134,7 +141,7 @@ class CustomModel(ArchitectureModel):
134
141
 
135
142
  Parameters
136
143
  ----------
137
- kwargs : Any
144
+ **kwargs : Any
138
145
  Additional keyword arguments from Pydantic BaseModel model_dump method.
139
146
 
140
147
  Returns
@@ -1,3 +1,5 @@
1
+ """Custom model registration utilities."""
2
+
1
3
  from typing import Callable
2
4
 
3
5
  from torch.nn import Module
@@ -53,7 +55,7 @@ def register_model(name: str) -> Callable:
53
55
  Parameters
54
56
  ----------
55
57
  model : Module
56
- Module class to register
58
+ Module class to register.
57
59
 
58
60
  Returns
59
61
  -------
@@ -1,3 +1,5 @@
1
+ """UNet Pydantic model."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import Literal
@@ -39,6 +41,7 @@ class UNetModel(ArchitectureModel):
39
41
  "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
40
42
  ] = Field(default="None", validate_default=True)
41
43
  n2v2: bool = Field(default=False, validate_default=True)
44
+ independent_channels: bool = Field(default=True, validate_default=True)
42
45
 
43
46
  @field_validator("num_channels_init")
44
47
  @classmethod
@@ -1,3 +1,5 @@
1
+ """VAE Pydantic model."""
2
+
1
3
  from typing import Literal
2
4
 
3
5
  from pydantic import (
@@ -1,4 +1,5 @@
1
- """Checkpoint saving configuration."""
1
+ """Callback Pydantic models."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  from datetime import timedelta
@@ -12,13 +13,7 @@ from pydantic import (
12
13
 
13
14
 
14
15
  class CheckpointModel(BaseModel):
15
- """_summary_.
16
-
17
- Parameters
18
- ----------
19
- BaseModel : _type_
20
- _description_
21
- """
16
+ """Checkpoint saving callback Pydantic model."""
22
17
 
23
18
  model_config = ConfigDict(
24
19
  validate_assignment=True,
@@ -45,13 +40,7 @@ class CheckpointModel(BaseModel):
45
40
 
46
41
 
47
42
  class EarlyStoppingModel(BaseModel):
48
- """_summary_.
49
-
50
- Parameters
51
- ----------
52
- BaseModel : _type_
53
- _description_
54
- """
43
+ """Early stopping callback Pydantic model."""
55
44
 
56
45
  model_config = ConfigDict(
57
46
  validate_assignment=True,
@@ -1,3 +1,5 @@
1
+ """Example of configurations."""
2
+
1
3
  from .algorithm_model import AlgorithmConfig
2
4
  from .architectures import UNetModel
3
5
  from .configuration_model import Configuration
@@ -19,7 +21,7 @@ from .training_model import TrainingConfig
19
21
 
20
22
 
21
23
  def full_configuration_example() -> Configuration:
22
- """Returns a dictionnary representing a full configuration example.
24
+ """Return a dictionnary representing a full configuration example.
23
25
 
24
26
  Returns
25
27
  -------
@@ -56,12 +58,10 @@ def full_configuration_example() -> Configuration:
56
58
  "name": SupportedTransform.NORMALIZE.value,
57
59
  },
58
60
  {
59
- "name": SupportedTransform.NDFLIP.value,
60
- "is_3D": False,
61
+ "name": SupportedTransform.XY_FLIP.value,
61
62
  },
62
63
  {
63
64
  "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
64
- "is_3D": False,
65
65
  },
66
66
  {
67
67
  "name": SupportedTransform.N2V_MANIPULATE.value,