careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""Function to export to the BioImage Model Zoo format."""
|
|
2
|
+
|
|
3
|
+
import tempfile
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from bioimageio.core import load_model_description, test_model
|
|
9
|
+
from bioimageio.spec import ValidationSummary, save_bioimageio_package
|
|
10
|
+
from pydantic import HttpUrl
|
|
11
|
+
from torch import __version__ as PYTORCH_VERSION
|
|
12
|
+
from torch import load, save
|
|
13
|
+
from torchvision import __version__ as TORCHVISION_VERSION
|
|
14
|
+
|
|
15
|
+
from careamics.config import Configuration, load_configuration, save_configuration
|
|
16
|
+
from careamics.config.support import SupportedArchitecture
|
|
17
|
+
from careamics.lightning.lightning_module import FCNModule, VAEModule
|
|
18
|
+
from careamics.utils.version import get_careamics_version
|
|
19
|
+
|
|
20
|
+
from .bioimage import (
|
|
21
|
+
create_env_text,
|
|
22
|
+
create_model_description,
|
|
23
|
+
extract_model_path,
|
|
24
|
+
)
|
|
25
|
+
from .bioimage.cover_factory import create_cover
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _export_state_dict(
    model: Union[FCNModule, VAEModule], path: Union[Path, str]
) -> Path:
    """
    Export the model state dictionary to a file.

    The file is always written with a `.pth` suffix: any other suffix, including
    a missing one, is replaced.

    Parameters
    ----------
    model : FCNModule or VAEModule
        CAREamics Lightning module; the state dictionary of its underlying torch
        model (`model.model`) is exported.
    path : Union[Path, str]
        Path to the file where to save the model state dictionary.

    Returns
    -------
    Path
        Path to the saved model state dictionary.
    """
    path = Path(path)

    # Enforce the `.pth` suffix. This must be an exact comparison: a substring
    # test (`suffix not in ".pth"`) would leave "", ".p", ".pt" or ".th" untouched.
    if path.suffix != ".pth":
        path = path.with_suffix(".pth")

    # Save through the torch model itself to avoid the leading "model." in the
    # layer names, which is incompatible with the way the BMZ loads torch state
    # dicts.
    save(model.model.state_dict(), path)

    return path
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _load_state_dict(
    model: Union[FCNModule, VAEModule], path: Union[Path, str]
) -> None:
    """
    Load weights from a state dictionary file into a CAREamics model.

    Parameters
    ----------
    model : FCNModule or VAEModule
        CAREamics Lightning module whose underlying torch model (`model.model`)
        is updated in place with the weights.
    path : Union[Path, str]
        Path to the model state dictionary.
    """
    path = Path(path)

    # Same as in `_export_state_dict`, the weights are loaded through the torch
    # model (`model.model`) to be compatible with bioimageio.core expectations
    # for a torch state dict (no leading "model." in the layer names).
    # - `map_location="cpu"`: weights saved from a GPU can still be read on a
    #   CPU-only machine; `load_state_dict` copies the values into the existing
    #   parameters, whatever their device.
    # - `weights_only=True`: the file may come from a downloaded archive, so
    #   restrict unpickling to tensors and plain containers.
    state_dict = load(path, map_location="cpu", weights_only=True)
    model.model.load_state_dict(state_dict)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _get_test_kwargs(model_description: object) -> dict:
    """Extract the `pytorch_state_dict` test kwargs from a model description.

    Reads `config["bioimageio"]["test_kwargs"]["pytorch_state_dict"]` from the
    description. An empty dictionary is returned when the description has no
    dictionary `config`, or when any level along that key path is missing or is
    not a dictionary.

    Parameters
    ----------
    model_description : object
        BioImage Model Zoo model description.

    Returns
    -------
    dict
        Keyword arguments to forward to `bioimageio.core.test_model`.
    """
    node: object = getattr(model_description, "config", None)
    for key in ("bioimageio", "test_kwargs", "pytorch_state_dict"):
        if not isinstance(node, dict):
            return {}
        node = node.get(key, {})
    return node if isinstance(node, dict) else {}


def export_to_bmz(
    model: Union[FCNModule, VAEModule],
    config: Configuration,
    path_to_archive: Union[Path, str],
    model_name: str,
    general_description: str,
    data_description: str,
    authors: list[dict],
    input_array: np.ndarray,
    output_array: np.ndarray,
    covers: list[Union[Path, str]] | None = None,
    channel_names: list[str] | None = None,
    model_version: str = "0.1.0",
) -> None:
    """Export the model to BioImage Model Zoo format.

    Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C.

    `model_name` should consist of letters, numbers, dashes, underscores and parentheses
    only.

    All intermediate files (environment, example arrays, configuration, weights
    and, if needed, cover) are written to a temporary folder, the resulting model
    description is tested with `bioimageio.core.test_model`, and the package is
    then saved to `path_to_archive`.

    Parameters
    ----------
    model : FCNModule or VAEModule
        CAREamics model to export.
    config : Configuration
        Model configuration.
    path_to_archive : Union[Path, str]
        Path to the output file, must have a `.zip` suffix. Missing parent folders
        are created.
    model_name : str
        Model name.
    general_description : str
        General description of the model.
    data_description : str
        Description of the data the model was trained on.
    authors : list[dict]
        Authors of the model.
    input_array : np.ndarray
        Input array, should not have been normalized.
    output_array : np.ndarray
        Output array, should have been denormalized.
    covers : list of pathlib.Path or str, default=None
        Paths to the cover images. If None, a cover is generated from the input
        and output arrays.
    channel_names : Optional[list[str]], optional
        Channel names, by default None.
    model_version : str, default="0.1.0"
        Model version.

    Raises
    ------
    ValueError
        If `path_to_archive` does not have a `.zip` suffix.
    ValueError
        If the test of the model description reports a `"failed"` status.
    """
    path_to_archive = Path(path_to_archive)

    if path_to_archive.suffix != ".zip":
        raise ValueError(
            f"Path to archive must point to a zip file, got {path_to_archive}."
        )

    # `exist_ok=True` makes a prior existence check unnecessary
    path_to_archive.parent.mkdir(parents=True, exist_ok=True)

    # versions
    careamics_version = get_careamics_version()

    # save files in temporary folder, the package must be assembled and saved
    # before the folder is deleted
    with tempfile.TemporaryDirectory() as tmpdirname:
        temp_path = Path(tmpdirname)

        # create environment file
        # TODO move in bioimage module
        env_path = temp_path / "environment.yml"
        env_path.write_text(create_env_text(PYTORCH_VERSION, TORCHVISION_VERSION))

        # export inputs and outputs
        inputs = temp_path / "inputs.npy"
        np.save(inputs, input_array)
        outputs = temp_path / "outputs.npy"
        np.save(outputs, output_array)

        # export configuration
        config_path = save_configuration(config, temp_path / "careamics.yaml")

        # export model state dictionary
        weight_path = _export_state_dict(model, temp_path / "weights.pth")

        # export cover if necessary
        if covers is None:
            covers = [create_cover(temp_path, input_array, output_array)]

        # create model description
        model_description = create_model_description(
            config=config,
            name=model_name,
            general_description=general_description,
            data_description=data_description,
            authors=authors,
            inputs=inputs,
            outputs=outputs,
            weights_path=weight_path,
            torch_version=PYTORCH_VERSION,
            careamics_version=careamics_version,
            config_path=config_path,
            env_path=env_path,
            covers=covers,
            channel_names=channel_names,
            model_version=model_version,
        )

        # test model description
        test_kwargs = _get_test_kwargs(model_description)
        summary: ValidationSummary = test_model(model_description, **test_kwargs)
        if summary.status == "failed":
            raise ValueError(f"Model description test failed: {summary}")

        # save bmz model
        save_bioimageio_package(model_description, output_path=path_to_archive)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def load_from_bmz(
    path: Union[Path, str, HttpUrl],
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
    """Load a model from a BioImage Model Zoo archive.

    Parameters
    ----------
    path : Path, str or HttpUrl
        Path to the BioImage Model Zoo archive. A Http URL must point to a downloadable
        location.

    Returns
    -------
    FCNModule or VAEModule
        The loaded CAREamics model: an `FCNModule` for a UNet architecture, a
        `VAEModule` for an LVAE architecture.
    Configuration
        The loaded CAREamics configuration.

    Raises
    ------
    ValueError
        If the architecture stored in the configuration is neither UNet nor LVAE.
    """
    # load description, this creates an unzipped folder next to the archive
    model_desc = load_model_description(path)

    # extract paths
    weights_path, config_path = extract_model_path(model_desc)

    # load configuration
    config = load_configuration(config_path)

    # create careamics lightning module matching the stored architecture
    architecture = config.algorithm_config.model.architecture
    if architecture == SupportedArchitecture.UNET:
        model = FCNModule(algorithm_config=config.algorithm_config)
    elif architecture == SupportedArchitecture.LVAE:
        model = VAEModule(algorithm_config=config.algorithm_config)
    else:
        raise ValueError(f"Unsupported architecture {architecture}")

    # load model state dictionary
    _load_state_dict(model, weights_path)

    return model, config
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Utility functions to load pretrained models."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from careamics.config import Configuration
|
|
9
|
+
from careamics.lightning.lightning_module import FCNModule, VAEModule
|
|
10
|
+
from careamics.model_io.bmz_io import load_from_bmz
|
|
11
|
+
from careamics.utils import check_path_exists
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_pretrained(
    path: Union[Path, str],
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
    """
    Load a pretrained model and its configuration.

    The loader is chosen from the file suffix: `.ckpt` files are treated as
    Lightning checkpoints and `.zip` files as BioImage Model Zoo archives.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the pretrained model; it must exist.

    Returns
    -------
    tuple[FCNModule or VAEModule, Configuration]
        The CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the file suffix is neither `.ckpt` nor `.zip`.
    """
    checked_path = check_path_exists(path)
    suffix = checked_path.suffix

    if suffix == ".ckpt":
        return _load_checkpoint(checked_path)

    if suffix == ".zip":
        return load_from_bmz(checked_path)

    raise ValueError(
        f"Invalid model format. Expected .ckpt or .zip, got {suffix}."
    )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _load_checkpoint(
    path: Union[Path, str],
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
    """
    Load a model from a checkpoint and return both model and configuration.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the checkpoint.

    Returns
    -------
    tuple[FCNModule or VAEModule, Configuration]
        Tuple of CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the checkpoint file does not contain hyper parameters (configuration),
        if these do not specify a model architecture, or if the architecture is
        neither "UNet" nor "LVAE".
    """
    # Only the `hyper_parameters` entry is read from this dictionary; the weights
    # are loaded separately by Lightning's `load_from_checkpoint` below. Mapping
    # to the CPU avoids copying the whole checkpoint onto the GPU just to read
    # the configuration, and works on machines without CUDA.
    # see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
    checkpoint: dict = torch.load(path, map_location="cpu")

    # attempt to load configuration
    try:
        cfg_dict = checkpoint["hyper_parameters"]
    except KeyError as e:
        raise ValueError(
            f"Invalid checkpoint file. No `hyper_parameters` found in the "
            f"checkpoint: {checkpoint.keys()}"
        ) from e

    try:
        architecture = cfg_dict["algorithm_config"]["model"]["architecture"]
    except (KeyError, TypeError) as e:
        raise ValueError(
            "Invalid checkpoint file. No model architecture found in the "
            "checkpoint `hyper_parameters`."
        ) from e

    if architecture == "UNet":
        model = FCNModule.load_from_checkpoint(path)
    elif architecture == "LVAE":
        model = VAEModule.load_from_checkpoint(path)
    else:
        raise ValueError(f"Invalid model architecture: {architecture}")

    return model, Configuration(**cfg_dict)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Activations for CAREamics models."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
from ..config.support import SupportedActivation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_activation(activation: Union[SupportedActivation, str]) -> Callable:
    """
    Instantiate the torch activation module matching `activation`.

    Parameters
    ----------
    activation : SupportedActivation or str
        Activation function name, compared against the `SupportedActivation`
        members.

    Returns
    -------
    Callable
        A freshly created `torch.nn` module; softmax is applied along dimension 1
        and `NONE` maps to an identity module.

    Raises
    ------
    ValueError
        If `activation` does not match any supported activation.
    """
    # Ordered (member, factory) pairs, checked in sequence with `==`
    candidates = (
        (SupportedActivation.RELU, nn.ReLU),
        (SupportedActivation.ELU, nn.ELU),
        (SupportedActivation.LEAKYRELU, nn.LeakyReLU),
        (SupportedActivation.TANH, nn.Tanh),
        (SupportedActivation.SIGMOID, nn.Sigmoid),
        (SupportedActivation.SOFTMAX, lambda: nn.Softmax(dim=1)),
        (SupportedActivation.NONE, nn.Identity),
    )

    for member, factory in candidates:
        if activation == member:
            return factory()

    raise ValueError(f"Activation {activation} not supported.")
|