careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +80 -44
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -2
- careamics/config/configuration_factory.py +4 -16
- careamics/config/data_model.py +10 -14
- careamics/config/inference_model.py +0 -65
- careamics/config/optimizer_models.py +4 -4
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/conftest.py +12 -0
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +71 -32
- careamics/dataset/iterable_dataset.py +155 -68
- careamics/dataset/patching/patching.py +56 -15
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/tiled_patching.py +3 -1
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +45 -19
- careamics/lightning_module.py +8 -2
- careamics/lightning_prediction_datamodule.py +3 -13
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/bmz_io.py +3 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction/stitch_prediction.py +2 -6
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +49 -13
- careamics/transforms/normalize.py +55 -3
- careamics/transforms/pixel_manipulation.py +5 -5
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -67
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Callback saving CAREamics configuration as hyperparameters in the model."""
|
|
2
|
+
|
|
1
3
|
from pytorch_lightning import LightningModule, Trainer
|
|
2
4
|
from pytorch_lightning.callbacks import Callback
|
|
3
5
|
|
|
@@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
|
|
|
11
13
|
This allows saving the configuration as dictionnary in the checkpoints, and
|
|
12
14
|
loading it subsequently in a CAREamist instance.
|
|
13
15
|
|
|
16
|
+
Parameters
|
|
17
|
+
----------
|
|
18
|
+
config : Configuration
|
|
19
|
+
CAREamics configuration to be saved as hyperparameter in the model.
|
|
20
|
+
|
|
14
21
|
Attributes
|
|
15
22
|
----------
|
|
16
23
|
config : Configuration
|
|
17
24
|
CAREamics configuration to be saved as hyperparameter in the model.
|
|
18
25
|
"""
|
|
19
26
|
|
|
20
|
-
def __init__(self, config: Configuration):
|
|
27
|
+
def __init__(self, config: Configuration) -> None:
|
|
21
28
|
"""
|
|
22
29
|
Constructor.
|
|
23
30
|
|
|
@@ -28,14 +35,14 @@ class HyperParametersCallback(Callback):
|
|
|
28
35
|
"""
|
|
29
36
|
self.config = config
|
|
30
37
|
|
|
31
|
-
def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
|
|
38
|
+
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
|
|
32
39
|
"""
|
|
33
40
|
Update the hyperparameters of the model with the configuration on train start.
|
|
34
41
|
|
|
35
42
|
Parameters
|
|
36
43
|
----------
|
|
37
44
|
trainer : Trainer
|
|
38
|
-
PyTorch Lightning trainer.
|
|
45
|
+
PyTorch Lightning trainer, unused.
|
|
39
46
|
pl_module : LightningModule
|
|
40
47
|
PyTorch Lightning module.
|
|
41
48
|
"""
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Progressbar callback."""
|
|
2
|
+
|
|
1
3
|
import sys
|
|
2
4
|
from typing import Dict, Union
|
|
3
5
|
|
|
@@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
10
12
|
"""Progress bar for training and validation steps."""
|
|
11
13
|
|
|
12
14
|
def init_train_tqdm(self) -> tqdm:
|
|
13
|
-
"""Override this to customize the tqdm bar for training.
|
|
15
|
+
"""Override this to customize the tqdm bar for training.
|
|
16
|
+
|
|
17
|
+
Returns
|
|
18
|
+
-------
|
|
19
|
+
tqdm
|
|
20
|
+
A tqdm bar.
|
|
21
|
+
"""
|
|
14
22
|
bar = tqdm(
|
|
15
23
|
desc="Training",
|
|
16
24
|
position=(2 * self.process_position),
|
|
@@ -23,7 +31,13 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
23
31
|
return bar
|
|
24
32
|
|
|
25
33
|
def init_validation_tqdm(self) -> tqdm:
|
|
26
|
-
"""Override this to customize the tqdm bar for validation.
|
|
34
|
+
"""Override this to customize the tqdm bar for validation.
|
|
35
|
+
|
|
36
|
+
Returns
|
|
37
|
+
-------
|
|
38
|
+
tqdm
|
|
39
|
+
A tqdm bar.
|
|
40
|
+
"""
|
|
27
41
|
# The main progress bar doesn't exist in `trainer.validate()`
|
|
28
42
|
has_main_bar = self.train_progress_bar is not None
|
|
29
43
|
bar = tqdm(
|
|
@@ -37,7 +51,13 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
37
51
|
return bar
|
|
38
52
|
|
|
39
53
|
def init_test_tqdm(self) -> tqdm:
|
|
40
|
-
"""Override this to customize the tqdm bar for testing.
|
|
54
|
+
"""Override this to customize the tqdm bar for testing.
|
|
55
|
+
|
|
56
|
+
Returns
|
|
57
|
+
-------
|
|
58
|
+
tqdm
|
|
59
|
+
A tqdm bar.
|
|
60
|
+
"""
|
|
41
61
|
bar = tqdm(
|
|
42
62
|
desc="Testing",
|
|
43
63
|
position=(2 * self.process_position),
|
|
@@ -52,6 +72,19 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
52
72
|
def get_metrics(
|
|
53
73
|
self, trainer: Trainer, pl_module: LightningModule
|
|
54
74
|
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
|
|
55
|
-
"""Override this to customize the metrics displayed in the progress bar.
|
|
75
|
+
"""Override this to customize the metrics displayed in the progress bar.
|
|
76
|
+
|
|
77
|
+
Parameters
|
|
78
|
+
----------
|
|
79
|
+
trainer : Trainer
|
|
80
|
+
The trainer object.
|
|
81
|
+
pl_module : LightningModule
|
|
82
|
+
The LightningModule object, unused.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
dict
|
|
87
|
+
A dictionary with the metrics to display in the progress bar.
|
|
88
|
+
"""
|
|
56
89
|
pbar_metrics = trainer.progress_bar_metrics
|
|
57
90
|
return {**pbar_metrics}
|
careamics/careamist.py
CHANGED
|
@@ -18,13 +18,14 @@ from careamics.config import (
|
|
|
18
18
|
create_inference_configuration,
|
|
19
19
|
load_configuration,
|
|
20
20
|
)
|
|
21
|
-
from careamics.config.inference_model import TRANSFORMS_UNION
|
|
22
21
|
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
|
|
22
|
+
from careamics.dataset.dataset_utils import reshape_array
|
|
23
23
|
from careamics.lightning_datamodule import CAREamicsTrainData
|
|
24
24
|
from careamics.lightning_module import CAREamicsModule
|
|
25
25
|
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
|
|
26
26
|
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
|
|
27
27
|
from careamics.model_io import export_to_bmz, load_pretrained
|
|
28
|
+
from careamics.transforms import Denormalize
|
|
28
29
|
from careamics.utils import check_path_exists, get_logger
|
|
29
30
|
|
|
30
31
|
from .callbacks import HyperParametersCallback
|
|
@@ -488,7 +489,6 @@ class CAREamist:
|
|
|
488
489
|
tile_overlap: Tuple[int, ...] = (48, 48),
|
|
489
490
|
axes: Optional[str] = None,
|
|
490
491
|
data_type: Optional[Literal["tiff", "custom"]] = None,
|
|
491
|
-
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
492
492
|
tta_transforms: bool = True,
|
|
493
493
|
dataloader_params: Optional[Dict] = None,
|
|
494
494
|
read_source_func: Optional[Callable] = None,
|
|
@@ -506,7 +506,6 @@ class CAREamist:
|
|
|
506
506
|
tile_overlap: Tuple[int, ...] = (48, 48),
|
|
507
507
|
axes: Optional[str] = None,
|
|
508
508
|
data_type: Optional[Literal["array"]] = None,
|
|
509
|
-
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
510
509
|
tta_transforms: bool = True,
|
|
511
510
|
dataloader_params: Optional[Dict] = None,
|
|
512
511
|
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
@@ -521,7 +520,6 @@ class CAREamist:
|
|
|
521
520
|
tile_overlap: Tuple[int, ...] = (48, 48),
|
|
522
521
|
axes: Optional[str] = None,
|
|
523
522
|
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
|
|
524
|
-
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
525
523
|
tta_transforms: bool = True,
|
|
526
524
|
dataloader_params: Optional[Dict] = None,
|
|
527
525
|
read_source_func: Optional[Callable] = None,
|
|
@@ -538,8 +536,6 @@ class CAREamist:
|
|
|
538
536
|
configuration parameters will be used, with the `patch_size` instead of
|
|
539
537
|
`tile_size`.
|
|
540
538
|
|
|
541
|
-
The default transforms are defined in the `InferenceModel` Pydantic model.
|
|
542
|
-
|
|
543
539
|
Test-time augmentation (TTA) can be switched off using the `tta_transforms`
|
|
544
540
|
parameter.
|
|
545
541
|
|
|
@@ -563,8 +559,6 @@ class CAREamist:
|
|
|
563
559
|
Axes of the input data, by default None.
|
|
564
560
|
data_type : Optional[Literal["array", "tiff", "custom"]], optional
|
|
565
561
|
Type of the input data, by default None.
|
|
566
|
-
transforms : Optional[List[TRANSFORMS_UNION]], optional
|
|
567
|
-
List of transforms to apply to the data, by default None.
|
|
568
562
|
tta_transforms : bool, optional
|
|
569
563
|
Whether to apply test-time augmentation, by default True.
|
|
570
564
|
dataloader_params : Optional[Dict], optional
|
|
@@ -608,7 +602,6 @@ class CAREamist:
|
|
|
608
602
|
tile_overlap=tile_overlap,
|
|
609
603
|
data_type=data_type,
|
|
610
604
|
axes=axes,
|
|
611
|
-
transforms=transforms,
|
|
612
605
|
tta_transforms=tta_transforms,
|
|
613
606
|
batch_size=batch_size,
|
|
614
607
|
)
|
|
@@ -660,38 +653,41 @@ class CAREamist:
|
|
|
660
653
|
f"np.ndarray (got {type(source)})."
|
|
661
654
|
)
|
|
662
655
|
|
|
663
|
-
def
|
|
656
|
+
def _create_data_for_bmz(
|
|
664
657
|
self,
|
|
665
|
-
path: Union[Path, str],
|
|
666
|
-
name: str,
|
|
667
|
-
authors: List[dict],
|
|
668
658
|
input_array: Optional[np.ndarray] = None,
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
data_description: Optional[str] = None,
|
|
672
|
-
) -> None:
|
|
673
|
-
"""Export the model to the BioImage Model Zoo format.
|
|
659
|
+
) -> np.ndarray:
|
|
660
|
+
"""Create data for BMZ export.
|
|
674
661
|
|
|
675
|
-
|
|
662
|
+
If no `input_array` is provided, this method checks if there is a prediction
|
|
663
|
+
datamodule, or a training data module, to extract a patch. If none exists,
|
|
664
|
+
then a random aray is created.
|
|
665
|
+
|
|
666
|
+
If there is a non-singleton batch dimension, this method returns only the first
|
|
667
|
+
element.
|
|
676
668
|
|
|
677
669
|
Parameters
|
|
678
670
|
----------
|
|
679
|
-
path : Union[Path, str]
|
|
680
|
-
Path to save the model.
|
|
681
|
-
name : str
|
|
682
|
-
Name of the model.
|
|
683
|
-
authors : List[dict]
|
|
684
|
-
List of authors of the model.
|
|
685
671
|
input_array : Optional[np.ndarray], optional
|
|
686
|
-
Input array
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
672
|
+
Input array, by default None.
|
|
673
|
+
|
|
674
|
+
Returns
|
|
675
|
+
-------
|
|
676
|
+
np.ndarray
|
|
677
|
+
Input data for BMZ export.
|
|
678
|
+
|
|
679
|
+
Raises
|
|
680
|
+
------
|
|
681
|
+
ValueError
|
|
682
|
+
If mean and std are not provided in the configuration.
|
|
693
683
|
"""
|
|
694
684
|
if input_array is None:
|
|
685
|
+
if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
|
|
686
|
+
raise ValueError(
|
|
687
|
+
"Mean and std cannot be None in the configuration in order to"
|
|
688
|
+
"export to the BMZ format. Was the model trained?"
|
|
689
|
+
)
|
|
690
|
+
|
|
695
691
|
# generate images, priority is given to the prediction data module
|
|
696
692
|
if self.pred_datamodule is not None:
|
|
697
693
|
# unpack a batch, ignore masks or targets
|
|
@@ -699,19 +695,23 @@ class CAREamist:
|
|
|
699
695
|
|
|
700
696
|
# convert torch.Tensor to numpy
|
|
701
697
|
input_patch = input_patch.numpy()
|
|
698
|
+
|
|
699
|
+
# denormalize
|
|
700
|
+
denormalize = Denormalize(
|
|
701
|
+
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
|
|
702
|
+
)
|
|
703
|
+
input_patch, _ = denormalize(input_patch)
|
|
704
|
+
|
|
702
705
|
elif self.train_datamodule is not None:
|
|
703
706
|
input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
|
|
704
707
|
input_patch = input_patch.numpy()
|
|
705
|
-
else:
|
|
706
|
-
if (
|
|
707
|
-
self.cfg.data_config.mean is None
|
|
708
|
-
or self.cfg.data_config.std is None
|
|
709
|
-
):
|
|
710
|
-
raise ValueError(
|
|
711
|
-
"Mean and std cannot be None in the configuration in order to"
|
|
712
|
-
"export to the BMZ format. Was the model trained?"
|
|
713
|
-
)
|
|
714
708
|
|
|
709
|
+
# denormalize
|
|
710
|
+
denormalize = Denormalize(
|
|
711
|
+
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
|
|
712
|
+
)
|
|
713
|
+
input_patch, _ = denormalize(input_patch)
|
|
714
|
+
else:
|
|
715
715
|
# create a random input array
|
|
716
716
|
input_patch = np.random.normal(
|
|
717
717
|
loc=self.cfg.data_config.mean,
|
|
@@ -721,11 +721,47 @@ class CAREamist:
|
|
|
721
721
|
np.newaxis, np.newaxis, ...
|
|
722
722
|
] # add S & C dimensions
|
|
723
723
|
else:
|
|
724
|
-
|
|
724
|
+
# potentially correct shape
|
|
725
|
+
input_patch = reshape_array(input_array, self.cfg.data_config.axes)
|
|
725
726
|
|
|
726
|
-
# if
|
|
727
|
+
# if this a batch
|
|
727
728
|
if input_patch.shape[0] > 1:
|
|
728
|
-
input_patch = input_patch[0
|
|
729
|
+
input_patch = input_patch[[0], ...] # keep singleton dim
|
|
730
|
+
|
|
731
|
+
return input_patch
|
|
732
|
+
|
|
733
|
+
def export_to_bmz(
|
|
734
|
+
self,
|
|
735
|
+
path: Union[Path, str],
|
|
736
|
+
name: str,
|
|
737
|
+
authors: List[dict],
|
|
738
|
+
input_array: Optional[np.ndarray] = None,
|
|
739
|
+
general_description: str = "",
|
|
740
|
+
channel_names: Optional[List[str]] = None,
|
|
741
|
+
data_description: Optional[str] = None,
|
|
742
|
+
) -> None:
|
|
743
|
+
"""Export the model to the BioImage Model Zoo format.
|
|
744
|
+
|
|
745
|
+
Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
|
|
746
|
+
|
|
747
|
+
Parameters
|
|
748
|
+
----------
|
|
749
|
+
path : Union[Path, str]
|
|
750
|
+
Path to save the model.
|
|
751
|
+
name : str
|
|
752
|
+
Name of the model.
|
|
753
|
+
authors : List[dict]
|
|
754
|
+
List of authors of the model.
|
|
755
|
+
input_array : Optional[np.ndarray], optional
|
|
756
|
+
Input array for the model, must be of shape SC(Z)YX, by default None.
|
|
757
|
+
general_description : str
|
|
758
|
+
General description of the model, used in the metadata of the BMZ archive.
|
|
759
|
+
channel_names : Optional[List[str]], optional
|
|
760
|
+
Channel names, by default None.
|
|
761
|
+
data_description : Optional[str], optional
|
|
762
|
+
Description of the data, by default None.
|
|
763
|
+
"""
|
|
764
|
+
input_patch = self._create_data_for_bmz(input_array)
|
|
729
765
|
|
|
730
766
|
# axes need to be reformated for the export because reshaping was done in the
|
|
731
767
|
# datamodule
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Algorithm configuration."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
from pprint import pformat
|
|
@@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel):
|
|
|
17
19
|
training algorithm: which algorithm, loss function, model architecture, optimizer,
|
|
18
20
|
and learning rate scheduler to use.
|
|
19
21
|
|
|
20
|
-
Currently, we only support N2V and custom
|
|
21
|
-
compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
|
|
22
|
-
you to register your own architecture and select it using its name as
|
|
22
|
+
Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
|
|
23
|
+
only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
|
|
24
|
+
allows you to register your own architecture and select it using its name as
|
|
23
25
|
`name` in the custom pydantic model.
|
|
24
26
|
|
|
25
27
|
Attributes
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Base model for the various CAREamics architectures."""
|
|
2
|
+
|
|
1
3
|
from typing import Any, Dict
|
|
2
4
|
|
|
3
5
|
from pydantic import BaseModel
|
|
@@ -16,6 +18,11 @@ class ArchitectureModel(BaseModel):
|
|
|
16
18
|
"""
|
|
17
19
|
Dump the model as a dictionary, ignoring the architecture keyword.
|
|
18
20
|
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
**kwargs : Any
|
|
24
|
+
Additional keyword arguments from Pydantic BaseModel model_dump method.
|
|
25
|
+
|
|
19
26
|
Returns
|
|
20
27
|
-------
|
|
21
28
|
dict[str, Any]
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Custom architecture Pydantic model."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
from pprint import pformat
|
|
@@ -84,6 +86,11 @@ class CustomModel(ArchitectureModel):
|
|
|
84
86
|
value : str
|
|
85
87
|
Name of the custom model as registered using the `@register_model`
|
|
86
88
|
decorator.
|
|
89
|
+
|
|
90
|
+
Returns
|
|
91
|
+
-------
|
|
92
|
+
str
|
|
93
|
+
The custom model name.
|
|
87
94
|
"""
|
|
88
95
|
# delegate error to get_custom_model
|
|
89
96
|
model = get_custom_model(value)
|
|
@@ -134,7 +141,7 @@ class CustomModel(ArchitectureModel):
|
|
|
134
141
|
|
|
135
142
|
Parameters
|
|
136
143
|
----------
|
|
137
|
-
kwargs : Any
|
|
144
|
+
**kwargs : Any
|
|
138
145
|
Additional keyword arguments from Pydantic BaseModel model_dump method.
|
|
139
146
|
|
|
140
147
|
Returns
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Custom model registration utilities."""
|
|
2
|
+
|
|
1
3
|
from typing import Callable
|
|
2
4
|
|
|
3
5
|
from torch.nn import Module
|
|
@@ -53,7 +55,7 @@ def register_model(name: str) -> Callable:
|
|
|
53
55
|
Parameters
|
|
54
56
|
----------
|
|
55
57
|
model : Module
|
|
56
|
-
Module class to register
|
|
58
|
+
Module class to register.
|
|
57
59
|
|
|
58
60
|
Returns
|
|
59
61
|
-------
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Callback Pydantic models."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
@@ -13,13 +13,7 @@ from pydantic import (
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class CheckpointModel(BaseModel):
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
Parameters
|
|
19
|
-
----------
|
|
20
|
-
BaseModel : _type_
|
|
21
|
-
_description_
|
|
22
|
-
"""
|
|
16
|
+
"""Checkpoint saving callback Pydantic model."""
|
|
23
17
|
|
|
24
18
|
model_config = ConfigDict(
|
|
25
19
|
validate_assignment=True,
|
|
@@ -46,13 +40,7 @@ class CheckpointModel(BaseModel):
|
|
|
46
40
|
|
|
47
41
|
|
|
48
42
|
class EarlyStoppingModel(BaseModel):
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
Parameters
|
|
52
|
-
----------
|
|
53
|
-
BaseModel : _type_
|
|
54
|
-
_description_
|
|
55
|
-
"""
|
|
43
|
+
"""Early stopping callback Pydantic model."""
|
|
56
44
|
|
|
57
45
|
model_config = ConfigDict(
|
|
58
46
|
validate_assignment=True,
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Example of configurations."""
|
|
2
|
+
|
|
1
3
|
from .algorithm_model import AlgorithmConfig
|
|
2
4
|
from .architectures import UNetModel
|
|
3
5
|
from .configuration_model import Configuration
|
|
@@ -19,7 +21,7 @@ from .training_model import TrainingConfig
|
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
def full_configuration_example() -> Configuration:
|
|
22
|
-
"""
|
|
24
|
+
"""Return a dictionnary representing a full configuration example.
|
|
23
25
|
|
|
24
26
|
Returns
|
|
25
27
|
-------
|
|
@@ -56,7 +58,7 @@ def full_configuration_example() -> Configuration:
|
|
|
56
58
|
"name": SupportedTransform.NORMALIZE.value,
|
|
57
59
|
},
|
|
58
60
|
{
|
|
59
|
-
"name": SupportedTransform.
|
|
61
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
60
62
|
},
|
|
61
63
|
{
|
|
62
64
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Convenience functions to create configurations for training and inference."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
3
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
4
4
|
|
|
5
5
|
from .algorithm_model import AlgorithmConfig
|
|
6
6
|
from .architectures import UNetModel
|
|
@@ -111,7 +111,7 @@ def _create_supervised_configuration(
|
|
|
111
111
|
"name": SupportedTransform.NORMALIZE.value,
|
|
112
112
|
},
|
|
113
113
|
{
|
|
114
|
-
"name": SupportedTransform.
|
|
114
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
115
115
|
},
|
|
116
116
|
{
|
|
117
117
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
@@ -526,7 +526,7 @@ def create_n2v_configuration(
|
|
|
526
526
|
"name": SupportedTransform.NORMALIZE.value,
|
|
527
527
|
},
|
|
528
528
|
{
|
|
529
|
-
"name": SupportedTransform.
|
|
529
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
530
530
|
},
|
|
531
531
|
{
|
|
532
532
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
@@ -587,7 +587,6 @@ def create_inference_configuration(
|
|
|
587
587
|
tile_overlap: Optional[Tuple[int, ...]] = None,
|
|
588
588
|
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
|
|
589
589
|
axes: Optional[str] = None,
|
|
590
|
-
transforms: Optional[Union[List[Dict[str, Any]]]] = None,
|
|
591
590
|
tta_transforms: bool = True,
|
|
592
591
|
batch_size: Optional[int] = 1,
|
|
593
592
|
) -> InferenceConfig:
|
|
@@ -595,7 +594,7 @@ def create_inference_configuration(
|
|
|
595
594
|
Create a configuration for inference with N2V.
|
|
596
595
|
|
|
597
596
|
If not provided, `data_type` and `axes` are taken from the training
|
|
598
|
-
configuration.
|
|
597
|
+
configuration.
|
|
599
598
|
|
|
600
599
|
Parameters
|
|
601
600
|
----------
|
|
@@ -609,8 +608,6 @@ def create_inference_configuration(
|
|
|
609
608
|
Type of the data, by default "tiff".
|
|
610
609
|
axes : str, optional
|
|
611
610
|
Axes of the data, by default "YX".
|
|
612
|
-
transforms : List[Dict[str, Any]], optional
|
|
613
|
-
Transformations to apply to the data, by default None.
|
|
614
611
|
tta_transforms : bool, optional
|
|
615
612
|
Whether to apply test-time augmentations, by default True.
|
|
616
613
|
batch_size : int, optional
|
|
@@ -624,14 +621,6 @@ def create_inference_configuration(
|
|
|
624
621
|
if configuration.data_config.mean is None or configuration.data_config.std is None:
|
|
625
622
|
raise ValueError("Mean and std must be provided in the configuration.")
|
|
626
623
|
|
|
627
|
-
# minimum transform
|
|
628
|
-
if transforms is None:
|
|
629
|
-
transforms = [
|
|
630
|
-
{
|
|
631
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
632
|
-
},
|
|
633
|
-
]
|
|
634
|
-
|
|
635
624
|
# tile size for UNets
|
|
636
625
|
if tile_size is not None:
|
|
637
626
|
model = configuration.algorithm_config.model
|
|
@@ -661,7 +650,6 @@ def create_inference_configuration(
|
|
|
661
650
|
axes=axes or configuration.data_config.axes,
|
|
662
651
|
mean=configuration.data_config.mean,
|
|
663
652
|
std=configuration.data_config.std,
|
|
664
|
-
transforms=transforms,
|
|
665
653
|
tta_transforms=tta_transforms,
|
|
666
654
|
batch_size=batch_size,
|
|
667
655
|
)
|
careamics/config/data_model.py
CHANGED
|
@@ -17,14 +17,14 @@ from typing_extensions import Annotated, Self
|
|
|
17
17
|
|
|
18
18
|
from .support import SupportedTransform
|
|
19
19
|
from .transformations.n2v_manipulate_model import N2VManipulateModel
|
|
20
|
-
from .transformations.nd_flip_model import NDFlipModel
|
|
21
20
|
from .transformations.normalize_model import NormalizeModel
|
|
21
|
+
from .transformations.xy_flip_model import XYFlipModel
|
|
22
22
|
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
23
23
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
24
24
|
|
|
25
25
|
TRANSFORMS_UNION = Annotated[
|
|
26
26
|
Union[
|
|
27
|
-
|
|
27
|
+
XYFlipModel,
|
|
28
28
|
XYRandomRotate90Model,
|
|
29
29
|
NormalizeModel,
|
|
30
30
|
N2VManipulateModel,
|
|
@@ -41,6 +41,8 @@ class DataConfig(BaseModel):
|
|
|
41
41
|
and then the mean (if they were both `None` before) will raise a validation error.
|
|
42
42
|
Prefer instead `set_mean_and_std` to set both at once.
|
|
43
43
|
|
|
44
|
+
All supported transforms are defined in the SupportedTransform enum.
|
|
45
|
+
|
|
44
46
|
Examples
|
|
45
47
|
--------
|
|
46
48
|
Minimum example:
|
|
@@ -56,7 +58,7 @@ class DataConfig(BaseModel):
|
|
|
56
58
|
>>> data.set_mean_and_std(mean=214.3, std=84.5)
|
|
57
59
|
|
|
58
60
|
One can pass also a list of transformations, by keyword, using the
|
|
59
|
-
SupportedTransform
|
|
61
|
+
SupportedTransform value:
|
|
60
62
|
>>> from careamics.config.support import SupportedTransform
|
|
61
63
|
>>> data = DataConfig(
|
|
62
64
|
... data_type="tiff",
|
|
@@ -70,7 +72,7 @@ class DataConfig(BaseModel):
|
|
|
70
72
|
... "std": 47.2,
|
|
71
73
|
... },
|
|
72
74
|
... {
|
|
73
|
-
... "name": "
|
|
75
|
+
... "name": "XYFlip",
|
|
74
76
|
... }
|
|
75
77
|
... ]
|
|
76
78
|
... )
|
|
@@ -97,7 +99,7 @@ class DataConfig(BaseModel):
|
|
|
97
99
|
"name": SupportedTransform.NORMALIZE.value,
|
|
98
100
|
},
|
|
99
101
|
{
|
|
100
|
-
"name": SupportedTransform.
|
|
102
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
101
103
|
},
|
|
102
104
|
{
|
|
103
105
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
@@ -202,7 +204,7 @@ class DataConfig(BaseModel):
|
|
|
202
204
|
|
|
203
205
|
if SupportedTransform.N2V_MANIPULATE in transform_list:
|
|
204
206
|
# multiple N2V_MANIPULATE
|
|
205
|
-
if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
|
|
207
|
+
if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
|
|
206
208
|
raise ValueError(
|
|
207
209
|
f"Multiple instances of "
|
|
208
210
|
f"{SupportedTransform.N2V_MANIPULATE} transforms "
|
|
@@ -211,7 +213,7 @@ class DataConfig(BaseModel):
|
|
|
211
213
|
|
|
212
214
|
# N2V_MANIPULATE not the last transform
|
|
213
215
|
elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
|
|
214
|
-
index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
|
|
216
|
+
index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
|
|
215
217
|
transform = transforms.pop(index)
|
|
216
218
|
transforms.append(transform)
|
|
217
219
|
|
|
@@ -250,7 +252,7 @@ class DataConfig(BaseModel):
|
|
|
250
252
|
Self
|
|
251
253
|
Data model with mean and std added to the Normalize transform.
|
|
252
254
|
"""
|
|
253
|
-
if self.mean is not None
|
|
255
|
+
if self.mean is not None and self.std is not None:
|
|
254
256
|
# search in the transforms for Normalize and update parameters
|
|
255
257
|
for transform in self.transforms:
|
|
256
258
|
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
@@ -355,12 +357,6 @@ class DataConfig(BaseModel):
|
|
|
355
357
|
"""
|
|
356
358
|
self._update(mean=mean, std=std)
|
|
357
359
|
|
|
358
|
-
# search in the transforms for Normalize and update parameters
|
|
359
|
-
for transform in self.transforms:
|
|
360
|
-
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
361
|
-
transform.mean = mean
|
|
362
|
-
transform.std = std
|
|
363
|
-
|
|
364
360
|
def set_3D(self, axes: str, patch_size: List[int]) -> None:
|
|
365
361
|
"""
|
|
366
362
|
Set 3D parameters.
|