careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release. This version of careamics might be problematic; details are available on the registry page.

Files changed (81)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +80 -44
  4. careamics/config/algorithm_model.py +5 -3
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +8 -1
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -2
  12. careamics/config/configuration_factory.py +4 -16
  13. careamics/config/data_model.py +10 -14
  14. careamics/config/inference_model.py +0 -65
  15. careamics/config/optimizer_models.py +4 -4
  16. careamics/config/support/__init__.py +0 -2
  17. careamics/config/support/supported_activations.py +2 -0
  18. careamics/config/support/supported_algorithms.py +3 -1
  19. careamics/config/support/supported_architectures.py +2 -0
  20. careamics/config/support/supported_data.py +2 -0
  21. careamics/config/support/supported_loggers.py +2 -0
  22. careamics/config/support/supported_losses.py +2 -0
  23. careamics/config/support/supported_optimizers.py +2 -0
  24. careamics/config/support/supported_pixel_manipulations.py +3 -3
  25. careamics/config/support/supported_struct_axis.py +2 -0
  26. careamics/config/support/supported_transforms.py +4 -15
  27. careamics/config/tile_information.py +2 -0
  28. careamics/config/transformations/__init__.py +3 -2
  29. careamics/config/transformations/xy_flip_model.py +43 -0
  30. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  31. careamics/conftest.py +12 -0
  32. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  33. careamics/dataset/dataset_utils/file_utils.py +4 -3
  34. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  35. careamics/dataset/dataset_utils/read_utils.py +2 -0
  36. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  37. careamics/dataset/in_memory_dataset.py +71 -32
  38. careamics/dataset/iterable_dataset.py +155 -68
  39. careamics/dataset/patching/patching.py +56 -15
  40. careamics/dataset/patching/random_patching.py +8 -2
  41. careamics/dataset/patching/sequential_patching.py +14 -8
  42. careamics/dataset/patching/tiled_patching.py +3 -1
  43. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  44. careamics/dataset/zarr_dataset.py +2 -0
  45. careamics/lightning_datamodule.py +45 -19
  46. careamics/lightning_module.py +8 -2
  47. careamics/lightning_prediction_datamodule.py +3 -13
  48. careamics/lightning_prediction_loop.py +8 -6
  49. careamics/losses/__init__.py +2 -3
  50. careamics/losses/loss_factory.py +1 -1
  51. careamics/losses/losses.py +11 -7
  52. careamics/model_io/bmz_io.py +3 -3
  53. careamics/models/activation.py +2 -0
  54. careamics/models/layers.py +121 -25
  55. careamics/models/model_factory.py +1 -1
  56. careamics/models/unet.py +35 -14
  57. careamics/prediction/stitch_prediction.py +2 -6
  58. careamics/transforms/__init__.py +2 -2
  59. careamics/transforms/compose.py +33 -7
  60. careamics/transforms/n2v_manipulate.py +49 -13
  61. careamics/transforms/normalize.py +55 -3
  62. careamics/transforms/pixel_manipulation.py +5 -5
  63. careamics/transforms/struct_mask_parameters.py +3 -1
  64. careamics/transforms/transform.py +10 -19
  65. careamics/transforms/xy_flip.py +123 -0
  66. careamics/transforms/xy_random_rotate90.py +38 -5
  67. careamics/utils/base_enum.py +28 -0
  68. careamics/utils/path_utils.py +2 -0
  69. careamics/utils/ram.py +2 -0
  70. careamics/utils/receptive_field.py +93 -87
  71. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
  72. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  73. careamics/config/noise_models.py +0 -162
  74. careamics/config/support/supported_extraction_strategies.py +0 -25
  75. careamics/config/transformations/nd_flip_model.py +0 -27
  76. careamics/losses/noise_model_factory.py +0 -40
  77. careamics/losses/noise_models.py +0 -524
  78. careamics/transforms/nd_flip.py +0 -67
  79. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  80. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  81. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/callbacks/hyperparameters_callback.py CHANGED
@@ -1,3 +1,5 @@
+ """Callback saving CAREamics configuration as hyperparameters in the model."""
+
  from pytorch_lightning import LightningModule, Trainer
  from pytorch_lightning.callbacks import Callback

@@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
      This allows saving the configuration as a dictionary in the checkpoints, and
      loading it subsequently in a CAREamist instance.

+     Parameters
+     ----------
+     config : Configuration
+         CAREamics configuration to be saved as hyperparameter in the model.
+
      Attributes
      ----------
      config : Configuration
          CAREamics configuration to be saved as hyperparameter in the model.
      """

-     def __init__(self, config: Configuration):
+     def __init__(self, config: Configuration) -> None:
          """
          Constructor.

@@ -28,14 +35,14 @@ class HyperParametersCallback(Callback):
          """
          self.config = config

-     def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
+     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
          """
          Update the hyperparameters of the model with the configuration on train start.

          Parameters
          ----------
          trainer : Trainer
-             PyTorch Lightning trainer.
+             PyTorch Lightning trainer, unused.
          pl_module : LightningModule
              PyTorch Lightning module.
          """
careamics/callbacks/progress_bar_callback.py CHANGED
@@ -1,3 +1,5 @@
+ """Progress bar callback."""
+
  import sys
  from typing import Dict, Union

@@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
      """Progress bar for training and validation steps."""

      def init_train_tqdm(self) -> tqdm:
-         """Override this to customize the tqdm bar for training."""
+         """Override this to customize the tqdm bar for training.
+
+         Returns
+         -------
+         tqdm
+             A tqdm bar.
+         """
          bar = tqdm(
              desc="Training",
              position=(2 * self.process_position),
@@ -23,7 +31,13 @@ class ProgressBarCallback(TQDMProgressBar):
          return bar

      def init_validation_tqdm(self) -> tqdm:
-         """Override this to customize the tqdm bar for validation."""
+         """Override this to customize the tqdm bar for validation.
+
+         Returns
+         -------
+         tqdm
+             A tqdm bar.
+         """
          # The main progress bar doesn't exist in `trainer.validate()`
          has_main_bar = self.train_progress_bar is not None
          bar = tqdm(
@@ -37,7 +51,13 @@ class ProgressBarCallback(TQDMProgressBar):
          return bar

      def init_test_tqdm(self) -> tqdm:
-         """Override this to customize the tqdm bar for testing."""
+         """Override this to customize the tqdm bar for testing.
+
+         Returns
+         -------
+         tqdm
+             A tqdm bar.
+         """
          bar = tqdm(
              desc="Testing",
              position=(2 * self.process_position),
@@ -52,6 +72,19 @@ class ProgressBarCallback(TQDMProgressBar):
      def get_metrics(
          self, trainer: Trainer, pl_module: LightningModule
      ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
-         """Override this to customize the metrics displayed in the progress bar."""
+         """Override this to customize the metrics displayed in the progress bar.
+
+         Parameters
+         ----------
+         trainer : Trainer
+             The trainer object.
+         pl_module : LightningModule
+             The LightningModule object, unused.
+
+         Returns
+         -------
+         dict
+             A dictionary with the metrics to display in the progress bar.
+         """
          pbar_metrics = trainer.progress_bar_metrics
          return {**pbar_metrics}
careamics/careamist.py CHANGED
@@ -18,13 +18,14 @@ from careamics.config import (
      create_inference_configuration,
      load_configuration,
  )
- from careamics.config.inference_model import TRANSFORMS_UNION
  from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
+ from careamics.dataset.dataset_utils import reshape_array
  from careamics.lightning_datamodule import CAREamicsTrainData
  from careamics.lightning_module import CAREamicsModule
  from careamics.lightning_prediction_datamodule import CAREamicsPredictData
  from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
  from careamics.model_io import export_to_bmz, load_pretrained
+ from careamics.transforms import Denormalize
  from careamics.utils import check_path_exists, get_logger

  from .callbacks import HyperParametersCallback
@@ -488,7 +489,6 @@ class CAREamist:
          tile_overlap: Tuple[int, ...] = (48, 48),
          axes: Optional[str] = None,
          data_type: Optional[Literal["tiff", "custom"]] = None,
-         transforms: Optional[List[TRANSFORMS_UNION]] = None,
          tta_transforms: bool = True,
          dataloader_params: Optional[Dict] = None,
          read_source_func: Optional[Callable] = None,
@@ -506,7 +506,6 @@ class CAREamist:
          tile_overlap: Tuple[int, ...] = (48, 48),
          axes: Optional[str] = None,
          data_type: Optional[Literal["array"]] = None,
-         transforms: Optional[List[TRANSFORMS_UNION]] = None,
          tta_transforms: bool = True,
          dataloader_params: Optional[Dict] = None,
          checkpoint: Optional[Literal["best", "last"]] = None,
@@ -521,7 +520,6 @@ class CAREamist:
          tile_overlap: Tuple[int, ...] = (48, 48),
          axes: Optional[str] = None,
          data_type: Optional[Literal["array", "tiff", "custom"]] = None,
-         transforms: Optional[List[TRANSFORMS_UNION]] = None,
          tta_transforms: bool = True,
          dataloader_params: Optional[Dict] = None,
          read_source_func: Optional[Callable] = None,
@@ -538,8 +536,6 @@ class CAREamist:
          configuration parameters will be used, with the `patch_size` instead of
          `tile_size`.

-         The default transforms are defined in the `InferenceModel` Pydantic model.
-
          Test-time augmentation (TTA) can be switched off using the `tta_transforms`
          parameter.
@@ -563,8 +559,6 @@ class CAREamist:
              Axes of the input data, by default None.
          data_type : Optional[Literal["array", "tiff", "custom"]], optional
              Type of the input data, by default None.
-         transforms : Optional[List[TRANSFORMS_UNION]], optional
-             List of transforms to apply to the data, by default None.
          tta_transforms : bool, optional
              Whether to apply test-time augmentation, by default True.
          dataloader_params : Optional[Dict], optional
@@ -608,7 +602,6 @@ class CAREamist:
              tile_overlap=tile_overlap,
              data_type=data_type,
              axes=axes,
-             transforms=transforms,
              tta_transforms=tta_transforms,
              batch_size=batch_size,
          )
@@ -660,38 +653,41 @@ class CAREamist:
                  f"np.ndarray (got {type(source)})."
              )

-     def export_to_bmz(
+     def _create_data_for_bmz(
          self,
-         path: Union[Path, str],
-         name: str,
-         authors: List[dict],
          input_array: Optional[np.ndarray] = None,
-         general_description: str = "",
-         channel_names: Optional[List[str]] = None,
-         data_description: Optional[str] = None,
-     ) -> None:
-         """Export the model to the BioImage Model Zoo format.
+     ) -> np.ndarray:
+         """Create data for BMZ export.

-         Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
+         If no `input_array` is provided, this method checks if there is a prediction
+         datamodule, or a training data module, to extract a patch. If none exists,
+         then a random array is created.
+
+         If there is a non-singleton batch dimension, this method returns only the first
+         element.

          Parameters
          ----------
-         path : Union[Path, str]
-             Path to save the model.
-         name : str
-             Name of the model.
-         authors : List[dict]
-             List of authors of the model.
          input_array : Optional[np.ndarray], optional
-             Input array for the model, must be of shape SC(Z)YX, by default None.
-         general_description : str
-             General description of the model, used in the metadata of the BMZ archive.
-         channel_names : Optional[List[str]], optional
-             Channel names, by default None.
-         data_description : Optional[str], optional
-             Description of the data, by default None.
+             Input array, by default None.
+
+         Returns
+         -------
+         np.ndarray
+             Input data for BMZ export.
+
+         Raises
+         ------
+         ValueError
+             If mean and std are not provided in the configuration.
          """
          if input_array is None:
+             if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
+                 raise ValueError(
+                     "Mean and std cannot be None in the configuration in order to"
+                     "export to the BMZ format. Was the model trained?"
+                 )
+
              # generate images, priority is given to the prediction data module
              if self.pred_datamodule is not None:
                  # unpack a batch, ignore masks or targets
@@ -699,19 +695,23 @@

                  # convert torch.Tensor to numpy
                  input_patch = input_patch.numpy()
+
+                 # denormalize
+                 denormalize = Denormalize(
+                     mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
+                 )
+                 input_patch, _ = denormalize(input_patch)
+
              elif self.train_datamodule is not None:
                  input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
                  input_patch = input_patch.numpy()
-             else:
-                 if (
-                     self.cfg.data_config.mean is None
-                     or self.cfg.data_config.std is None
-                 ):
-                     raise ValueError(
-                         "Mean and std cannot be None in the configuration in order to"
-                         "export to the BMZ format. Was the model trained?"
-                     )

+                 # denormalize
+                 denormalize = Denormalize(
+                     mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
+                 )
+                 input_patch, _ = denormalize(input_patch)
+             else:
                  # create a random input array
                  input_patch = np.random.normal(
                      loc=self.cfg.data_config.mean,
@@ -721,11 +721,47 @@
                      np.newaxis, np.newaxis, ...
                  ]  # add S & C dimensions
          else:
-             input_patch = input_array
+             # potentially correct shape
+             input_patch = reshape_array(input_array, self.cfg.data_config.axes)

-         # if there is a batch dimension
+         # if this is a batch
          if input_patch.shape[0] > 1:
-             input_patch = input_patch[0:1, ...]  # keep singleton dim
+             input_patch = input_patch[[0], ...]  # keep singleton dim
+
+         return input_patch
+
+     def export_to_bmz(
+         self,
+         path: Union[Path, str],
+         name: str,
+         authors: List[dict],
+         input_array: Optional[np.ndarray] = None,
+         general_description: str = "",
+         channel_names: Optional[List[str]] = None,
+         data_description: Optional[str] = None,
+     ) -> None:
+         """Export the model to the BioImage Model Zoo format.
+
+         Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
+
+         Parameters
+         ----------
+         path : Union[Path, str]
+             Path to save the model.
+         name : str
+             Name of the model.
+         authors : List[dict]
+             List of authors of the model.
+         input_array : Optional[np.ndarray], optional
+             Input array for the model, must be of shape SC(Z)YX, by default None.
+         general_description : str
+             General description of the model, used in the metadata of the BMZ archive.
+         channel_names : Optional[List[str]], optional
+             Channel names, by default None.
+         data_description : Optional[str], optional
+             Description of the data, by default None.
+         """
+         input_patch = self._create_data_for_bmz(input_array)

          # axes need to be reformatted for the export because reshaping was done in the
          # datamodule
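Two details of the new `_create_data_for_bmz` helper are easy to miss. Patches coming out of the datamodules are normalized, so the added `Denormalize` calls map the BMZ sample back to raw intensities (presumably `x * std + mean`, the inverse of standardization), and both `[0:1, ...]` and the new `[[0], ...]` indexing trim a batch while keeping the singleton batch axis. A standalone sketch with made-up statistics:

```python
import numpy as np

# Made-up statistics and batch; illustrates the denormalize + trim steps above.
mean, std = 214.3, 84.5
normalized = np.random.normal(size=(4, 1, 64, 64)).astype(np.float32)

# Inverse of (x - mean) / std, which is presumably what Denormalize applies.
raw = normalized * std + mean

# Trim to the first element while keeping the singleton batch axis:
assert raw[[0], ...].shape == (1, 1, 64, 64)  # list index keeps the axis
assert raw[0:1, ...].shape == (1, 1, 64, 64)  # so does a length-1 slice
assert raw[0].shape == (1, 64, 64)            # plain indexing drops it
```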
careamics/config/algorithm_model.py CHANGED
@@ -1,3 +1,5 @@
+ """Algorithm configuration."""
+
  from __future__ import annotations

  from pprint import pformat
@@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel):
      training algorithm: which algorithm, loss function, model architecture, optimizer,
      and learning rate scheduler to use.

-     Currently, we only support N2V and custom algorithms. The `n2v` algorithm is only
-     compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm allows
-     you to register your own architecture and select it using its name as
+     Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
+     only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
+     allows you to register your own architecture and select it using its name as
      `name` in the custom pydantic model.

      Attributes
careamics/config/architectures/architecture_model.py CHANGED
@@ -1,3 +1,5 @@
+ """Base model for the various CAREamics architectures."""
+
  from typing import Any, Dict

  from pydantic import BaseModel
@@ -16,6 +18,11 @@ class ArchitectureModel(BaseModel):
          """
          Dump the model as a dictionary, ignoring the architecture keyword.

+         Parameters
+         ----------
+         **kwargs : Any
+             Additional keyword arguments from Pydantic BaseModel model_dump method.
+
          Returns
          -------
          dict[str, Any]
careamics/config/architectures/custom_model.py CHANGED
@@ -1,3 +1,5 @@
+ """Custom architecture Pydantic model."""
+
  from __future__ import annotations

  from pprint import pformat
@@ -84,6 +86,11 @@ class CustomModel(ArchitectureModel):
          value : str
              Name of the custom model as registered using the `@register_model`
              decorator.
+
+         Returns
+         -------
+         str
+             The custom model name.
          """
          # delegate error to get_custom_model
          model = get_custom_model(value)
@@ -134,7 +141,7 @@ class CustomModel(ArchitectureModel):

          Parameters
          ----------
-         kwargs : Any
+         **kwargs : Any
              Additional keyword arguments from Pydantic BaseModel model_dump method.

          Returns
careamics/config/architectures/register_model.py CHANGED
@@ -1,3 +1,5 @@
+ """Custom model registration utilities."""
+
  from typing import Callable

  from torch.nn import Module
@@ -53,7 +55,7 @@ def register_model(name: str) -> Callable:
      Parameters
      ----------
      model : Module
-         Module class to register
+         Module class to register.

      Returns
      -------
careamics/config/architectures/unet_model.py CHANGED
@@ -1,3 +1,5 @@
+ """UNet Pydantic model."""
+
  from __future__ import annotations

  from typing import Literal
careamics/config/architectures/vae_model.py CHANGED
@@ -1,3 +1,5 @@
+ """VAE Pydantic model."""
+
  from typing import Literal

  from pydantic import (
careamics/config/callback_model.py CHANGED
@@ -1,4 +1,4 @@
- """Checkpoint saving configuration."""
+ """Callback Pydantic models."""

  from __future__ import annotations

@@ -13,13 +13,7 @@ from pydantic import (


  class CheckpointModel(BaseModel):
-     """_summary_.
-
-     Parameters
-     ----------
-     BaseModel : _type_
-         _description_
-     """
+     """Checkpoint saving callback Pydantic model."""

      model_config = ConfigDict(
          validate_assignment=True,
@@ -46,13 +40,7 @@ class CheckpointModel(BaseModel):


  class EarlyStoppingModel(BaseModel):
-     """_summary_.
-
-     Parameters
-     ----------
-     BaseModel : _type_
-         _description_
-     """
+     """Early stopping callback Pydantic model."""

      model_config = ConfigDict(
          validate_assignment=True,
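The placeholder `_summary_` docstrings are replaced with one-liners; functionally, both models keep Pydantic's `validate_assignment=True` (visible in the retained context lines), which re-validates fields mutated after construction. A standalone analogue with a hypothetical model, not the real `EarlyStoppingModel`:

```python
from pydantic import BaseModel, ConfigDict, ValidationError

class EarlyStoppingDemo(BaseModel):
    """Hypothetical stand-in illustrating validate_assignment."""

    model_config = ConfigDict(validate_assignment=True)
    patience: int = 3

cb = EarlyStoppingDemo()
try:
    cb.patience = "not a number"  # re-validated on assignment and rejected
except ValidationError as err:
    print(err)
```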
careamics/config/configuration_example.py CHANGED
@@ -1,3 +1,5 @@
+ """Example of configurations."""
+
  from .algorithm_model import AlgorithmConfig
  from .architectures import UNetModel
  from .configuration_model import Configuration
@@ -19,7 +21,7 @@ from .training_model import TrainingConfig


  def full_configuration_example() -> Configuration:
-     """Returns a dictionary representing a full configuration example.
+     """Return a dictionary representing a full configuration example.

      Returns
      -------
@@ -56,7 +58,7 @@ def full_configuration_example() -> Configuration:
              "name": SupportedTransform.NORMALIZE.value,
          },
          {
-             "name": SupportedTransform.NDFLIP.value,
+             "name": SupportedTransform.XY_FLIP.value,
          },
          {
              "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
careamics/config/configuration_factory.py CHANGED
@@ -1,6 +1,6 @@
  """Convenience functions to create configurations for training and inference."""

- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+ from typing import Any, Dict, List, Literal, Optional, Tuple

  from .algorithm_model import AlgorithmConfig
  from .architectures import UNetModel
@@ -111,7 +111,7 @@ def _create_supervised_configuration(
              "name": SupportedTransform.NORMALIZE.value,
          },
          {
-             "name": SupportedTransform.NDFLIP.value,
+             "name": SupportedTransform.XY_FLIP.value,
          },
          {
              "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -526,7 +526,7 @@ def create_n2v_configuration(
              "name": SupportedTransform.NORMALIZE.value,
          },
          {
-             "name": SupportedTransform.NDFLIP.value,
+             "name": SupportedTransform.XY_FLIP.value,
          },
          {
              "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -587,7 +587,6 @@ def create_inference_configuration(
      tile_overlap: Optional[Tuple[int, ...]] = None,
      data_type: Optional[Literal["array", "tiff", "custom"]] = None,
      axes: Optional[str] = None,
-     transforms: Optional[Union[List[Dict[str, Any]]]] = None,
      tta_transforms: bool = True,
      batch_size: Optional[int] = 1,
  ) -> InferenceConfig:
@@ -595,7 +594,7 @@
      Create a configuration for inference with N2V.

      If not provided, `data_type` and `axes` are taken from the training
-     configuration. If `transforms` are not provided, only normalization is applied.
+     configuration.

      Parameters
      ----------
@@ -609,8 +608,6 @@
          Type of the data, by default "tiff".
      axes : str, optional
          Axes of the data, by default "YX".
-     transforms : List[Dict[str, Any]], optional
-         Transformations to apply to the data, by default None.
      tta_transforms : bool, optional
          Whether to apply test-time augmentations, by default True.
      batch_size : int, optional
@@ -624,14 +621,6 @@
      if configuration.data_config.mean is None or configuration.data_config.std is None:
          raise ValueError("Mean and std must be provided in the configuration.")

-     # minimum transform
-     if transforms is None:
-         transforms = [
-             {
-                 "name": SupportedTransform.NORMALIZE.value,
-             },
-         ]
-
      # tile size for UNets
      if tile_size is not None:
          model = configuration.algorithm_config.model
@@ -661,7 +650,6 @@
          axes=axes or configuration.data_config.axes,
          mean=configuration.data_config.mean,
          std=configuration.data_config.std,
-         transforms=transforms,
          tta_transforms=tta_transforms,
          batch_size=batch_size,
      )
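With the `transforms` parameter gone, `create_inference_configuration` always derives normalization from the training configuration's mean and std. A hedged usage sketch; the keyword names come from the signature fragments above, and the first parameter is assumed to be called `configuration`:

```python
from careamics.config import create_inference_configuration, load_configuration

# "config.yml" is a hypothetical path to a trained config with mean/std set.
config = load_configuration("config.yml")

inference_config = create_inference_configuration(
    configuration=config,   # assumed name of the first parameter
    tile_size=(256, 256),   # optional tiling, as in the docstring above
    tile_overlap=(48, 48),
    tta_transforms=True,    # TTA remains configurable
    batch_size=1,
)
```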
careamics/config/data_model.py CHANGED
@@ -17,14 +17,14 @@ from typing_extensions import Annotated, Self

  from .support import SupportedTransform
  from .transformations.n2v_manipulate_model import N2VManipulateModel
- from .transformations.nd_flip_model import NDFlipModel
  from .transformations.normalize_model import NormalizeModel
+ from .transformations.xy_flip_model import XYFlipModel
  from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2

  TRANSFORMS_UNION = Annotated[
      Union[
-         NDFlipModel,
+         XYFlipModel,
          XYRandomRotate90Model,
          NormalizeModel,
          N2VManipulateModel,
@@ -41,6 +41,8 @@ class DataConfig(BaseModel):
      and then the mean (if they were both `None` before) will raise a validation error.
      Prefer instead `set_mean_and_std` to set both at once.

+     All supported transforms are defined in the SupportedTransform enum.
+
      Examples
      --------
      Minimum example:
@@ -56,7 +58,7 @@ class DataConfig(BaseModel):
      >>> data.set_mean_and_std(mean=214.3, std=84.5)

      One can also pass a list of transformations, by keyword, using the
-     SupportedTransform or the name of an Albumentation transform:
+     SupportedTransform value:
      >>> from careamics.config.support import SupportedTransform
      >>> data = DataConfig(
      ...     data_type="tiff",
@@ -70,7 +72,7 @@ class DataConfig(BaseModel):
      ...         "std": 47.2,
      ...     },
      ...     {
-     ...         "name": "NDFlip",
+     ...         "name": "XYFlip",
      ...     }
      ... ]
      ... )
@@ -97,7 +99,7 @@ class DataConfig(BaseModel):
              "name": SupportedTransform.NORMALIZE.value,
          },
          {
-             "name": SupportedTransform.NDFLIP.value,
+             "name": SupportedTransform.XY_FLIP.value,
          },
          {
              "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -202,7 +204,7 @@ class DataConfig(BaseModel):

          if SupportedTransform.N2V_MANIPULATE in transform_list:
              # multiple N2V_MANIPULATE
-             if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
+             if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
                  raise ValueError(
                      f"Multiple instances of "
                      f"{SupportedTransform.N2V_MANIPULATE} transforms "
@@ -211,7 +213,7 @@ class DataConfig(BaseModel):

              # N2V_MANIPULATE not the last transform
              elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
-                 index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
+                 index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
                  transform = transforms.pop(index)
                  transforms.append(transform)
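The `.value` additions matter because `transform_list` holds the transforms' `name` fields, which are plain strings. Unless `SupportedTransform` mixes in `str`, comparing a string against an enum member is always False, so `count` and `index` silently miss (the membership test on the first line presumably works through the custom `BaseEnum` added in `careamics/utils/base_enum.py`). A minimal illustration with a hypothetical stand-in enum:

```python
from enum import Enum

class TransformName(Enum):  # hypothetical stand-in for SupportedTransform
    N2V_MANIPULATE = "N2VManipulate"

names = ["XYFlip", "N2VManipulate"]  # transforms' `name` fields: plain strings
assert names.count(TransformName.N2V_MANIPULATE) == 0        # member != str
assert names.count(TransformName.N2V_MANIPULATE.value) == 1  # value matches
assert names.index(TransformName.N2V_MANIPULATE.value) == 1
```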
 
@@ -250,7 +252,7 @@ class DataConfig(BaseModel):
          Self
              Data model with mean and std added to the Normalize transform.
          """
-         if self.mean is not None or self.std is not None:
+         if self.mean is not None and self.std is not None:
              # search in the transforms for Normalize and update parameters
              for transform in self.transforms:
                  if transform.name == SupportedTransform.NORMALIZE.value:
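The `or` to `and` change fixes a subtle validator bug: with `or`, a half-initialized model (mean set, std still `None`) would enter the branch and copy `std=None` into the Normalize transform. A minimal sketch of the guard's behavior:

```python
# Half-initialized statistics, as can happen before set_mean_and_std is called.
mean, std = 214.3, None

if mean is not None or std is not None:   # old guard: True, would copy std=None
    pass
if mean is not None and std is not None:  # new guard: False, waits for both
    pass
```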
@@ -355,12 +357,6 @@ class DataConfig(BaseModel):
          """
          self._update(mean=mean, std=std)

-         # search in the transforms for Normalize and update parameters
-         for transform in self.transforms:
-             if transform.name == SupportedTransform.NORMALIZE.value:
-                 transform.mean = mean
-                 transform.std = std
-
      def set_3D(self, axes: str, patch_size: List[int]) -> None:
          """
          Set 3D parameters.