careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +92 -55
- careamics/config/__init__.py +0 -1
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +4 -15
- careamics/config/configuration_example.py +4 -4
- careamics/config/configuration_factory.py +113 -55
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +63 -165
- careamics/config/inference_model.py +9 -75
- careamics/config/optimizer_models.py +4 -4
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +13 -7
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +13 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +5 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +84 -76
- careamics/dataset/iterable_dataset.py +166 -134
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +56 -14
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +20 -14
- careamics/dataset/patching/tiled_patching.py +13 -7
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +63 -41
- careamics/lightning_module.py +9 -3
- careamics/lightning_prediction_datamodule.py +15 -20
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +1 -3
- careamics/losses/loss_factory.py +2 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +4 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +122 -25
- careamics/models/model_factory.py +2 -1
- careamics/models/unet.py +114 -19
- careamics/prediction/stitch_prediction.py +2 -5
- careamics/transforms/__init__.py +4 -25
- careamics/transforms/compose.py +124 -0
- careamics/transforms/n2v_manipulate.py +65 -34
- careamics/transforms/normalize.py +91 -28
- careamics/transforms/pixel_manipulation.py +7 -7
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +66 -60
- careamics/utils/__init__.py +0 -1
- careamics/utils/base_enum.py +28 -0
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -24
- careamics/config/transformations/nd_flip_model.py +0 -32
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -93
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/callbacks/hyperparameters_callback.py
CHANGED

@@ -1,3 +1,5 @@
+"""Callback saving CAREamics configuration as hyperparameters in the model."""
+
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
 
@@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
     This allows saving the configuration as dictionnary in the checkpoints, and
     loading it subsequently in a CAREamist instance.
 
+    Parameters
+    ----------
+    config : Configuration
+        CAREamics configuration to be saved as hyperparameter in the model.
+
     Attributes
     ----------
     config : Configuration
         CAREamics configuration to be saved as hyperparameter in the model.
     """
 
-    def __init__(self, config: Configuration):
+    def __init__(self, config: Configuration) -> None:
         """
         Constructor.
 
@@ -28,14 +35,14 @@ class HyperParametersCallback(Callback):
         """
         self.config = config
 
-    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         """
         Update the hyperparameters of the model with the configuration on train start.
 
         Parameters
         ----------
         trainer : Trainer
-            PyTorch Lightning trainer.
+            PyTorch Lightning trainer, unused.
         pl_module : LightningModule
             PyTorch Lightning module.
         """
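For context, a minimal usage sketch (not part of the package) showing how this callback is attached; the configuration path is hypothetical:

from pytorch_lightning import Trainer

from careamics.callbacks import HyperParametersCallback
from careamics.config import load_configuration

# The callback copies the configuration into pl_module.hparams on train start,
# so it is stored in the checkpoint and can be reloaded by a CAREamist later.
config = load_configuration("config.yml")  # hypothetical path
trainer = Trainer(max_epochs=1, callbacks=[HyperParametersCallback(config)])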
careamics/callbacks/progress_bar_callback.py
CHANGED

@@ -1,3 +1,5 @@
+"""Progressbar callback."""
+
 import sys
 from typing import Dict, Union
 
@@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
     """Progress bar for training and validation steps."""
 
     def init_train_tqdm(self) -> tqdm:
-        """Override this to customize the tqdm bar for training."""
+        """Override this to customize the tqdm bar for training.
+
+        Returns
+        -------
+        tqdm
+            A tqdm bar.
+        """
         bar = tqdm(
             desc="Training",
             position=(2 * self.process_position),
@@ -23,7 +31,13 @@ class ProgressBarCallback(TQDMProgressBar):
         return bar
 
     def init_validation_tqdm(self) -> tqdm:
-        """Override this to customize the tqdm bar for validation."""
+        """Override this to customize the tqdm bar for validation.
+
+        Returns
+        -------
+        tqdm
+            A tqdm bar.
+        """
         # The main progress bar doesn't exist in `trainer.validate()`
         has_main_bar = self.train_progress_bar is not None
         bar = tqdm(
@@ -37,7 +51,13 @@ class ProgressBarCallback(TQDMProgressBar):
         return bar
 
     def init_test_tqdm(self) -> tqdm:
-        """Override this to customize the tqdm bar for testing."""
+        """Override this to customize the tqdm bar for testing.
+
+        Returns
+        -------
+        tqdm
+            A tqdm bar.
+        """
         bar = tqdm(
             desc="Testing",
             position=(2 * self.process_position),
@@ -52,6 +72,19 @@ class ProgressBarCallback(TQDMProgressBar):
     def get_metrics(
         self, trainer: Trainer, pl_module: LightningModule
     ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
-        """Override this to customize the metrics displayed in the progress bar."""
+        """Override this to customize the metrics displayed in the progress bar.
+
+        Parameters
+        ----------
+        trainer : Trainer
+            The trainer object.
+        pl_module : LightningModule
+            The LightningModule object, unused.
+
+        Returns
+        -------
+        dict
+            A dictionary with the metrics to display in the progress bar.
+        """
         pbar_metrics = trainer.progress_bar_metrics
         return {**pbar_metrics}
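Subclasses can filter what the bar displays by overriding get_metrics; a standard PyTorch Lightning pattern (not careamics code) that hides the version counter:

from pytorch_lightning.callbacks import TQDMProgressBar

class QuietProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        # start from the default metrics and drop Lightning's v_num entry
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items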
careamics/careamist.py
CHANGED
@@ -18,13 +18,14 @@ from careamics.config import (
     create_inference_configuration,
     load_configuration,
 )
-from careamics.config.inference_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
+from careamics.dataset.dataset_utils import reshape_array
 from careamics.lightning_datamodule import CAREamicsTrainData
 from careamics.lightning_module import CAREamicsModule
 from careamics.lightning_prediction_datamodule import CAREamicsPredictData
 from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
 from careamics.model_io import export_to_bmz, load_pretrained
+from careamics.transforms import Denormalize
 from careamics.utils import check_path_exists, get_logger
 
 from .callbacks import HyperParametersCallback
@@ -73,8 +74,7 @@ class CAREamist:
         source: Union[Path, str],
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @overload
     def __init__(  # numpydoc ignore=GL08
@@ -82,8 +82,7 @@ class CAREamist:
         source: Configuration,
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
-    ) -> None:
-        ...
+    ) -> None: ...
 
     def __init__(
         self,
@@ -478,8 +477,7 @@ class CAREamist:
         source: CAREamicsPredictData,
         *,
         checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]:
-        ...
+    ) -> Union[list, np.ndarray]: ...
 
     @overload
     def predict(  # numpydoc ignore=GL08
@@ -491,14 +489,12 @@ class CAREamist:
         tile_overlap: Tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["tiff", "custom"]] = None,
-        transforms: Optional[List[TRANSFORMS_UNION]] = None,
         tta_transforms: bool = True,
         dataloader_params: Optional[Dict] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]:
-        ...
+    ) -> Union[list, np.ndarray]: ...
 
     @overload
     def predict(  # numpydoc ignore=GL08
@@ -510,12 +506,10 @@ class CAREamist:
         tile_overlap: Tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array"]] = None,
-        transforms: Optional[List[TRANSFORMS_UNION]] = None,
         tta_transforms: bool = True,
         dataloader_params: Optional[Dict] = None,
         checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]:
-        ...
+    ) -> Union[list, np.ndarray]: ...
 
     def predict(
         self,
@@ -526,7 +520,6 @@ class CAREamist:
         tile_overlap: Tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array", "tiff", "custom"]] = None,
-        transforms: Optional[List[TRANSFORMS_UNION]] = None,
         tta_transforms: bool = True,
         dataloader_params: Optional[Dict] = None,
         read_source_func: Optional[Callable] = None,
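The `...` bodies moved onto the signature lines in the overload stubs above; both spellings are equivalent. A generic typing.overload sketch (not careamics code):

from typing import Union, overload

@overload
def double(x: int) -> int: ...
@overload
def double(x: float) -> float: ...
def double(x: Union[int, float]) -> Union[int, float]:
    # single runtime implementation backing both overload signatures
    return x * 2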
@@ -543,11 +536,15 @@ class CAREamist:
         configuration parameters will be used, with the `patch_size` instead of
         `tile_size`.
 
-        The default transforms are defined in the `InferenceModel` Pydantic model.
-
         Test-time augmentation (TTA) can be switched off using the `tta_transforms`
         parameter.
 
+        Note that if you are using a UNet model and tiling, the tile size must be
+        divisible in every dimension by 2**d, where d is the depth of the model. This
+        avoids artefacts arising from the broken shift invariance induced by the
+        pooling layers of the UNet. If your image has less dimensions, as it may
+        happen in the Z dimension, consider padding your image.
+
         Parameters
         ----------
         source : Union[CAREamicsClay, Path, str, np.ndarray]
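The divisibility constraint in the new docstring note can be checked ahead of time; a minimal helper sketch (not careamics API):

def tile_fits_unet(tile_size, depth):
    """Return True if every tile dimension is divisible by 2**depth."""
    return all(dim % (2**depth) == 0 for dim in tile_size)

print(tile_fits_unet((64, 64), depth=2))   # True: 64 is divisible by 2**2 = 4
print(tile_fits_unet((48, 50), depth=3))   # False: 50 is not divisible by 2**3 = 8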
@@ -562,8 +559,6 @@ class CAREamist:
             Axes of the input data, by default None.
         data_type : Optional[Literal["array", "tiff", "custom"]], optional
             Type of the input data, by default None.
-        transforms : Optional[List[TRANSFORMS_UNION]], optional
-            List of transforms to apply to the data, by default None.
         tta_transforms : bool, optional
             Whether to apply test-time augmentation, by default True.
         dataloader_params : Optional[Dict], optional
@@ -602,12 +597,11 @@ class CAREamist:
             )
             # create predict config, reuse training config if parameters missing
             prediction_config = create_inference_configuration(
-
+                configuration=self.cfg,
                 tile_size=tile_size,
                 tile_overlap=tile_overlap,
                 data_type=data_type,
                 axes=axes,
-                transforms=transforms,
                 tta_transforms=tta_transforms,
                 batch_size=batch_size,
             )
@@ -659,38 +653,41 @@ class CAREamist:
                 f"np.ndarray (got {type(source)})."
             )
 
-    def export_to_bmz(
+    def _create_data_for_bmz(
         self,
-        path: Union[Path, str],
-        name: str,
-        authors: List[dict],
         input_array: Optional[np.ndarray] = None,
-        general_description: str = "",
-        channel_names: Optional[List[str]] = None,
-        data_description: Optional[str] = None,
-    ) -> None:
-        """Export the model to the BioImage Model Zoo format.
+    ) -> np.ndarray:
+        """Create data for BMZ export.
 
-
+        If no `input_array` is provided, this method checks if there is a prediction
+        datamodule, or a training data module, to extract a patch. If none exists,
+        then a random aray is created.
+
+        If there is a non-singleton batch dimension, this method returns only the first
+        element.
 
         Parameters
         ----------
-        path : Union[Path, str]
-            Path to save the model.
-        name : str
-            Name of the model.
-        authors : List[dict]
-            List of authors of the model.
         input_array : Optional[np.ndarray], optional
-            Input array for the model, by default None.
-        general_description : str
-            General description of the model, used in the metadata of the BMZ archive.
-        channel_names : Optional[List[str]], optional
-            Channel names, by default None.
-        data_description : Optional[str], optional
-            Description of the data, by default None.
+            Input array, by default None.
+
+        Returns
+        -------
+        np.ndarray
+            Input data for BMZ export.
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the configuration.
         """
         if input_array is None:
+            if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
+                raise ValueError(
+                    "Mean and std cannot be None in the configuration in order to"
+                    "export to the BMZ format. Was the model trained?"
+                )
+
             # generate images, priority is given to the prediction data module
             if self.pred_datamodule is not None:
                 # unpack a batch, ignore masks or targets
@@ -698,19 +695,23 @@ class CAREamist:
 
                 # convert torch.Tensor to numpy
                 input_patch = input_patch.numpy()
+
+                # denormalize
+                denormalize = Denormalize(
+                    mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
+                )
+                input_patch, _ = denormalize(input_patch)
+
             elif self.train_datamodule is not None:
                 input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
                 input_patch = input_patch.numpy()
-            else:
-                if (
-                    self.cfg.data_config.mean is None
-                    or self.cfg.data_config.std is None
-                ):
-                    raise ValueError(
-                        "Mean and std cannot be None in the configuration in order to"
-                        "export to the BMZ format. Was the model trained?"
-                    )
 
+                # denormalize
+                denormalize = Denormalize(
+                    mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
+                )
+                input_patch, _ = denormalize(input_patch)
+            else:
                 # create a random input array
                 input_patch = np.random.normal(
                     loc=self.cfg.data_config.mean,
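Denormalize (imported at the top of this file in the diff above) maps normalized patches back to the original intensity range before export. The underlying arithmetic, sketched with plain NumPy rather than the careamics transform:

import numpy as np

mean, std = 128.0, 32.0                 # hypothetical training statistics
patch = np.random.randn(1, 1, 64, 64)   # normalized SCYX patch
denormalized = patch * std + mean       # inverse of the (x - mean) / std normalization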
@@ -720,11 +721,47 @@ class CAREamist:
                 np.newaxis, np.newaxis, ...
             ] # add S & C dimensions
         else:
-            input_patch = input_array
+            # potentially correct shape
+            input_patch = reshape_array(input_array, self.cfg.data_config.axes)
 
-        # if
+        # if this a batch
         if input_patch.shape[0] > 1:
-            input_patch = input_patch[0]
+            input_patch = input_patch[[0], ...]  # keep singleton dim
+
+        return input_patch
+
+    def export_to_bmz(
+        self,
+        path: Union[Path, str],
+        name: str,
+        authors: List[dict],
+        input_array: Optional[np.ndarray] = None,
+        general_description: str = "",
+        channel_names: Optional[List[str]] = None,
+        data_description: Optional[str] = None,
+    ) -> None:
+        """Export the model to the BioImage Model Zoo format.
+
+        Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
+
+        Parameters
+        ----------
+        path : Union[Path, str]
+            Path to save the model.
+        name : str
+            Name of the model.
+        authors : List[dict]
+            List of authors of the model.
+        input_array : Optional[np.ndarray], optional
+            Input array for the model, must be of shape SC(Z)YX, by default None.
+        general_description : str
+            General description of the model, used in the metadata of the BMZ archive.
+        channel_names : Optional[List[str]], optional
+            Channel names, by default None.
+        data_description : Optional[str], optional
+            Description of the data, by default None.
+        """
+        input_patch = self._create_data_for_bmz(input_array)
 
         # axes need to be reformated for the export because reshaping was done in the
         # datamodule
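A usage sketch based on the signature above, assuming `careamist` is a trained CAREamist instance; the path and author entry are hypothetical:

careamist.export_to_bmz(
    path="n2v_model.zip",
    name="n2v_example",
    authors=[{"name": "Jane Doe", "affiliation": "Example Lab"}],
    general_description="Noise2Void model trained for this example.",
)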
careamics/config/algorithm_model.py
CHANGED

@@ -1,3 +1,5 @@
+"""Algorithm configuration."""
+
 from __future__ import annotations
 
 from pprint import pformat
@@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel):
     training algorithm: which algorithm, loss function, model architecture, optimizer,
     and learning rate scheduler to use.
 
-    Currently, we only support N2V and custom models. The `n2v` algorithm is only
-    compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm allows
-    you to register your own architecture and select it using its name as
+    Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
+    only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
+    allows you to register your own architecture and select it using its name as
     `name` in the custom pydantic model.
 
     Attributes
careamics/config/architectures/architecture_model.py
CHANGED

@@ -1,3 +1,5 @@
+"""Base model for the various CAREamics architectures."""
+
 from typing import Any, Dict
 
 from pydantic import BaseModel
@@ -16,6 +18,11 @@ class ArchitectureModel(BaseModel):
         """
         Dump the model as a dictionary, ignoring the architecture keyword.
 
+        Parameters
+        ----------
+        **kwargs : Any
+            Additional keyword arguments from Pydantic BaseModel model_dump method.
+
         Returns
         -------
         dict[str, Any]
careamics/config/architectures/custom_model.py
CHANGED

@@ -1,3 +1,5 @@
+"""Custom architecture Pydantic model."""
+
 from __future__ import annotations
 
 from pprint import pformat
@@ -84,6 +86,11 @@ class CustomModel(ArchitectureModel):
         value : str
             Name of the custom model as registered using the `@register_model`
             decorator.
+
+        Returns
+        -------
+        str
+            The custom model name.
         """
         # delegate error to get_custom_model
         model = get_custom_model(value)
@@ -134,7 +141,7 @@ class CustomModel(ArchitectureModel):
 
         Parameters
         ----------
-        kwargs : Any
+        **kwargs : Any
             Additional keyword arguments from Pydantic BaseModel model_dump method.
 
         Returns
careamics/config/architectures/register_model.py
CHANGED

@@ -1,3 +1,5 @@
+"""Custom model registration utilities."""
+
 from typing import Callable
 
 from torch.nn import Module
@@ -53,7 +55,7 @@ def register_model(name: str) -> Callable:
     Parameters
     ----------
     model : Module
-        Module class to register
+        Module class to register.
 
     Returns
     -------
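For context, a sketch of the registration flow this module provides; the import path is inferred from the file location and should be treated as an assumption:

from torch.nn import Module

from careamics.config.architectures.register_model import register_model

@register_model(name="my_model")  # this name later selects the class in CustomModel
class MyModel(Module):
    def forward(self, x):
        # identity placeholder; a real architecture would go here
        return x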
careamics/config/architectures/unet_model.py
CHANGED

@@ -1,3 +1,5 @@
+"""UNet Pydantic model."""
+
 from __future__ import annotations
 
 from typing import Literal
@@ -39,6 +41,7 @@ class UNetModel(ArchitectureModel):
         "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
     ] = Field(default="None", validate_default=True)
     n2v2: bool = Field(default=False, validate_default=True)
+    independent_channels: bool = Field(default=True, validate_default=True)
 
     @field_validator("num_channels_init")
     @classmethod
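The new field follows the same Field(default=..., validate_default=True) pattern as its neighbours; in Pydantic v2 this makes field validators run on default values too. A generic illustration (plain Pydantic, not careamics code):

from pydantic import BaseModel, Field, field_validator

class Example(BaseModel):
    flag: bool = Field(default=True, validate_default=True)

    @field_validator("flag")
    @classmethod
    def check_flag(cls, value: bool) -> bool:
        # with validate_default=True this also runs for Example() with no arguments
        return value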
careamics/config/callback_model.py
CHANGED

@@ -1,4 +1,5 @@
-"""
+"""Callback Pydantic models."""
+
 from __future__ import annotations
 
 from datetime import timedelta
@@ -12,13 +13,7 @@ from pydantic import (
 
 
 class CheckpointModel(BaseModel):
-    """
-
-    Parameters
-    ----------
-    BaseModel : _type_
-        _description_
-    """
+    """Checkpoint saving callback Pydantic model."""
 
     model_config = ConfigDict(
         validate_assignment=True,
@@ -45,13 +40,7 @@ class CheckpointModel(BaseModel):
 
 
 class EarlyStoppingModel(BaseModel):
-    """
-
-    Parameters
-    ----------
-    BaseModel : _type_
-        _description_
-    """
+    """Early stopping callback Pydantic model."""
 
     model_config = ConfigDict(
         validate_assignment=True,
careamics/config/configuration_example.py
CHANGED

@@ -1,3 +1,5 @@
+"""Example of configurations."""
+
 from .algorithm_model import AlgorithmConfig
 from .architectures import UNetModel
 from .configuration_model import Configuration
@@ -19,7 +21,7 @@ from .training_model import TrainingConfig
 
 
 def full_configuration_example() -> Configuration:
-    """
+    """Return a dictionnary representing a full configuration example.
 
     Returns
     -------
@@ -56,12 +58,10 @@ def full_configuration_example() -> Configuration:
                 "name": SupportedTransform.NORMALIZE.value,
             },
             {
-                "name": SupportedTransform.NDFLIP.value,
-                "is_3D": False,
+                "name": SupportedTransform.XY_FLIP.value,
             },
             {
                 "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
-                "is_3D": False,
             },
             {
                 "name": SupportedTransform.N2V_MANIPULATE.value,