careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Functions used to create a README.md file for BMZ export."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import yaml
|
|
6
|
+
|
|
7
|
+
from careamics.config import Configuration
|
|
8
|
+
from careamics.utils import cwd, get_careamics_home
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _yaml_block(yaml_str: str) -> str:
|
|
12
|
+
"""Return a markdown code block with a yaml string.
|
|
13
|
+
|
|
14
|
+
Parameters
|
|
15
|
+
----------
|
|
16
|
+
yaml_str : str
|
|
17
|
+
YAML string.
|
|
18
|
+
|
|
19
|
+
Returns
|
|
20
|
+
-------
|
|
21
|
+
str
|
|
22
|
+
Markdown code block with the YAML string.
|
|
23
|
+
"""
|
|
24
|
+
return f"```yaml\n{yaml_str}\n```"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def readme_factory(
    config: Configuration,
    careamics_version: str,
    data_description: str,
) -> Path:
    """Create a README file for the model.

    `data_description` can be used to add more information about the content of the
    data the model was trained on.

    The README is written as ``README.md`` inside the directory returned by
    ``get_careamics_home()`` (entered through the ``cwd`` context manager); an
    existing file with that name is overwritten.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    careamics_version : str
        CAREamics version.
    data_description : str
        Description of the data.

    Returns
    -------
    Path
        Absolute path to the README file.
    """
    # TODO use tempfile as in the bmz_io module
    with cwd(get_careamics_home()):
        readme = Path("README.md")

        # algorithm pretty name
        algorithm_flavour = config.get_algorithm_friendly_name()
        algorithm_pretty_name = algorithm_flavour + " - CAREamics"

        # a single H1 title; every other section is an H2
        description = [f"# {algorithm_pretty_name}\n\n"]

        # data description
        description.append("## Data description\n\n")
        description.append(data_description)
        description.append("\n\n")

        # algorithm description
        description.append("## Algorithm description\n\n")
        description.append(config.get_algorithm_description())
        description.append("\n\n")

        # configuration description
        description.append("## Configuration\n\n")
        description.append(
            f"{algorithm_flavour} was trained using CAREamics (version "
            f"{careamics_version}) using the following configuration:\n\n"
        )
        description.append(_yaml_block(yaml.dump(config.model_dump(exclude_none=True))))
        description.append("\n\n")

        # validation
        description.append("## Validation\n\n")
        description.append(
            "In order to validate the model, we encourage users to acquire a "
            "test dataset with ground-truth data. Comparing the ground-truth data "
            "with the prediction allows unbiased evaluation of the model performance. "
            "This can be done for instance by using metrics such as PSNR, SSIM, or "
            "MicroSSIM. In the absence of ground-truth, inspecting the residual image "
            "(difference between input and predicted image) can be helpful to identify "
            "whether real signal is removed from the input image.\n\n"
        )

        # references (section omitted when the algorithm provides none)
        reference = config.get_algorithm_references()
        if reference:
            description.append("## References\n\n")
            description.append(reference)
            description.append("\n\n")

        # links
        description.append(
            "## Links\n\n"
            "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
            "- [CAREamics documentation](https://careamics.github.io/)\n"
        )

        # write_text creates/truncates the file; force UTF-8 so the result does
        # not depend on the platform default encoding (descriptions and
        # references may contain non-ASCII characters)
        readme.write_text("".join(description), encoding="utf-8")

        # resolved while still inside the CAREamics home directory, so the
        # relative "README.md" maps to the file that was just written
        return readme.absolute()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Bioimage.io utils."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
from careamics.utils.version import get_careamics_version
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_unzip_path(zip_path: Union[Path, str]) -> Path:
    """Return the folder into which a bioimage.io model archive is unzipped.

    The folder sits next to the archive and is named after the archive with an
    ``.unzip`` suffix appended, e.g. ``models/model.zip`` becomes
    ``models/model.zip.unzip``.

    Parameters
    ----------
    zip_path : Union[Path, str]
        Path to the bioimage.io model archive.

    Returns
    -------
    Path
        Path to the unzipped folder.
    """
    archive = Path(zip_path)
    folder_name = f"{archive.name}.unzip"
    return archive.parent / folder_name
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def create_env_text(pytorch_version: str, torchvision_version: str) -> str:
    """Create environment yaml content for the bioimage model.

    The environment is a conda environment with Python 3.12 and pip. PyTorch,
    torchvision and CAREamics are installed through pip, with CAREamics pinned to
    the version reported by ``get_careamics_version()``.

    Parameters
    ----------
    pytorch_version : str
        Pytorch version.
    torchvision_version : str
        Torchvision version.

    Returns
    -------
    str
        Environment text (conda environment YAML, without a trailing newline).
    """
    # Conda dependencies are list items indented by two spaces; the pip packages
    # must be nested under the ``pip:`` entry, hence their four-space indent.
    lines = [
        "name: careamics",
        "dependencies:",
        "  - python=3.12",
        "  - pip",
        "  - pip:",
        f"    - torch=={pytorch_version}",
        f"    - torchvision=={torchvision_version}",
        f"    - careamics=={get_careamics_version()}",
    ]
    return "\n".join(lines)
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Convenience function to create covers for the BMZ."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
from PIL import Image
|
|
8
|
+
|
|
9
|
+
# RGB colours, one row per channel, used by `_four_channel_image` to blend the
# first four channels of a multichannel slice into a single RGB cover image.
color_palette = np.array(
    [
        [255, 195, 0],  # yellow
        [189, 226, 240],  # light blue
        [96, 60, 76],  # dark plum
        [193, 225, 193],  # pale green
    ]
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _get_norm_slice(array: NDArray) -> NDArray:
|
|
20
|
+
"""Get the normalized middle slice of a 4D or 5D array (SC(Z)YX).
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
array : NDArray
|
|
25
|
+
Array from which to get the middle slice.
|
|
26
|
+
|
|
27
|
+
Returns
|
|
28
|
+
-------
|
|
29
|
+
NDArray
|
|
30
|
+
Normalized middle slice of the input array.
|
|
31
|
+
"""
|
|
32
|
+
if array.ndim not in (4, 5):
|
|
33
|
+
raise ValueError("Array must be 4D or 5D.")
|
|
34
|
+
|
|
35
|
+
channels = array.shape[1] > 1
|
|
36
|
+
z_stack = array.ndim == 5
|
|
37
|
+
|
|
38
|
+
# get slice
|
|
39
|
+
if z_stack:
|
|
40
|
+
array_slice = array[0, :, array.shape[2] // 2, ...]
|
|
41
|
+
else:
|
|
42
|
+
array_slice = array[0, ...]
|
|
43
|
+
|
|
44
|
+
# channels
|
|
45
|
+
if channels:
|
|
46
|
+
array_slice = np.moveaxis(array_slice, 0, -1)
|
|
47
|
+
else:
|
|
48
|
+
array_slice = array_slice[0, ...]
|
|
49
|
+
|
|
50
|
+
# normalize
|
|
51
|
+
array_slice = (
|
|
52
|
+
255
|
|
53
|
+
* (array_slice - array_slice.min())
|
|
54
|
+
/ (array_slice.max() - array_slice.min())
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
return array_slice.astype(np.uint8)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _four_channel_image(array: NDArray) -> Image.Image:
    """Convert 4-channel array to Image.

    Each of the first four channels is treated as an intensity in [0, 255] that
    weights one colour of ``color_palette``. The weighted colours are summed per
    pixel and clipped to [0, 255].

    Parameters
    ----------
    array : NDArray
        Normalized, channels-last array to convert (expected to come from
        ``_get_norm_slice``, i.e. ``uint8`` values in [0, 255]).

    Returns
    -------
    Image.Image
        Converted array as an RGB image.
    """
    # (Y, X, 4, 1) weights in [0, 1] times (1, 1, 4, 3) colours -> (Y, X, 4, 3)
    weights = array[..., :4, np.newaxis].astype(np.float64) / 255.0
    colors = color_palette[np.newaxis, np.newaxis, :, :]
    blended = np.sum(weights * colors, axis=-2)

    # Several bright channels can push a pixel above 255; clip instead of letting
    # the uint8 cast wrap around.
    four_c_array = np.clip(blended, 0, 255).astype(np.uint8)

    return Image.fromarray(four_c_array).convert("RGB")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _convert_to_image(
    original_shape: tuple[int, ...], array: NDArray
) -> Image.Image:
    """Convert to Image.

    The number of channels is read from ``original_shape[1]``:

    - 1 channel: greyscale, expanded to RGB;
    - 2 channels: an empty channel is prepended, so the two channels become
      green and blue;
    - 3 channels: used directly as RGB;
    - 4 or more channels: the first four are blended with
      ``_four_channel_image``.

    Parameters
    ----------
    original_shape : tuple
        Original shape of the array (SC(Z)YX, index 1 is the channel count).
    array : NDArray
        Normalized array to convert.

    Returns
    -------
    Image.Image
        Converted array as an RGB image.
    """
    n_channels = original_shape[1]

    if n_channels <= 1:
        return Image.fromarray(array).convert("L").convert("RGB")

    if n_channels == 2:
        # add an empty (red) channel in front of the two existing channels
        padded = np.concatenate([np.zeros_like(array[..., 0:1]), array], axis=-1)
        return Image.fromarray(padded).convert("RGB")

    if n_channels == 3:
        return Image.fromarray(array).convert("RGB")

    # four or more channels: only the first four are represented
    return _four_channel_image(array[..., :4])
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def create_cover(directory: Path, array_in: NDArray, array_out: NDArray) -> Path:
    """Create a cover image from input and output arrays.

    Input and output arrays are expected to be SC(Z)YX. For images with a Z
    dimension, the middle slice is taken. The cover shows one half of the input
    next to (or above) the complementary half of the output.

    Parameters
    ----------
    directory : Path
        Directory in which to save the cover.
    array_in : numpy.ndarray
        Input array from which to create the cover image.
    array_out : numpy.ndarray
        Output array from which to create the cover image.

    Returns
    -------
    Path
        Path to the saved cover image (``directory / "cover.png"``).

    Raises
    ------
    ValueError
        If the input and output slices differ in both height and width.
    """
    # extract slice and normalize arrays: (Y, X) or (Y, X, C)
    slice_in = _get_norm_slice(array_in)
    slice_out = _get_norm_slice(array_out)

    # Compare spatial dimensions only. Indexing with [-1] / [-2] would compare
    # the channel count (instead of the width) for multichannel (Y, X, C) slices.
    in_height, in_width = slice_in.shape[:2]
    out_height, out_width = slice_out.shape[:2]

    horizontal_split = in_width == out_width
    if not horizontal_split and in_height != out_height:
        raise ValueError("Input and output arrays have different shapes.")

    # convert to Image
    image_in = _convert_to_image(array_in.shape, slice_in)
    image_out = _convert_to_image(array_out.shape, slice_out)

    full_width, full_height = image_in.width, image_in.height
    cover = Image.new("RGB", (full_width, full_height))

    # The output part is cropped at the same coordinates it is pasted to, so
    # both halves stay aligned and no empty column/row is left for odd sizes.
    if horizontal_split:
        # left part from the input, right part from the output
        width = full_width // 2
        cover.paste(image_in.crop((0, 0, width, full_height)), (0, 0))
        cover.paste(image_out.crop((width, 0, full_width, full_height)), (width, 0))
    else:
        # top part from the input, bottom part from the output
        height = full_height // 2
        cover.paste(image_in.crop((0, 0, full_width, height)), (0, 0))
        cover.paste(image_out.crop((0, height, full_width, full_height)), (0, height))

    # save
    cover_path = directory / "cover.png"
    cover.save(cover_path)

    return cover_path
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
"""Module use to build BMZ model description."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from bioimageio.spec._internal.io import extract
|
|
8
|
+
from bioimageio.spec.model.v0_5 import (
|
|
9
|
+
ArchitectureFromLibraryDescr,
|
|
10
|
+
Author,
|
|
11
|
+
AxisBase,
|
|
12
|
+
AxisId,
|
|
13
|
+
BatchAxis,
|
|
14
|
+
ChannelAxis,
|
|
15
|
+
FileDescr,
|
|
16
|
+
FixedZeroMeanUnitVarianceAlongAxisKwargs,
|
|
17
|
+
FixedZeroMeanUnitVarianceDescr,
|
|
18
|
+
Identifier,
|
|
19
|
+
InputTensorDescr,
|
|
20
|
+
ModelDescr,
|
|
21
|
+
OutputTensorDescr,
|
|
22
|
+
PytorchStateDictWeightsDescr,
|
|
23
|
+
SpaceInputAxis,
|
|
24
|
+
SpaceOutputAxis,
|
|
25
|
+
TensorId,
|
|
26
|
+
Version,
|
|
27
|
+
WeightsDescr,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
from careamics.config import Configuration, DataConfig
|
|
31
|
+
|
|
32
|
+
from ._readme_factory import readme_factory
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _create_axes(
    array: np.ndarray,
    data_config: DataConfig,
    channel_names: list[str] | None = None,
    is_input: bool = True,
) -> list[AxisBase]:
    """Create axes description.

    Array shape is expected to be SC(Z)YX. The description always contains a
    batch axis, a channel axis (a single ``"channel"`` entry when the
    configuration has no ``C`` axis) and the spatial axes, whose sizes are read
    from ``array.shape[2:]``.

    Parameters
    ----------
    array : np.ndarray
        Array.
    data_config : DataConfig
        CAREamics data configuration.
    channel_names : Optional[list[str]], optional
        Channel names, by default None.
    is_input : bool, optional
        Whether the axes are input axes, by default True.

    Returns
    -------
    list[AxisBase]
        list of axes description.

    Raises
    ------
    ValueError
        If channel names are not provided when channel axis is present.
    """
    # The array is SC(Z)YX, so its spatial dimensions are ZYX or YX. Deriving
    # them from the array layout (rather than stripping S and C from the config
    # string) keeps sizes aligned even if the config contains other axes (e.g. T).
    spatial_axes = "ZYX" if "Z" in data_config.axes else "YX"

    # batch is always present
    axes_model: list[AxisBase] = [BatchAxis()]

    if "C" in data_config.axes:
        if channel_names is None:
            raise ValueError(
                f"Channel names must be provided if channel axis is present, axes: "
                f"{data_config.axes}."
            )
        axes_model.append(
            ChannelAxis(channel_names=[Identifier(name) for name in channel_names])
        )
    else:
        # singleton channel
        axes_model.append(ChannelAxis(channel_names=[Identifier("channel")]))

    # spatial axes, located after the S and C dimensions of the array
    space_axis = SpaceInputAxis if is_input else SpaceOutputAxis
    for ind, axis_name in enumerate(spatial_axes):
        axes_model.append(
            space_axis(id=AxisId(axis_name.lower()), size=array.shape[2 + ind])
        )

    return axes_model
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _create_inputs_ouputs(
    input_array: np.ndarray,
    output_array: np.ndarray,
    data_config: DataConfig,
    input_path: Union[Path, str],
    output_path: Union[Path, str],
    channel_names: list[str] | None = None,
) -> tuple[InputTensorDescr, OutputTensorDescr]:
    """Create input and output tensor description.

    Input and output paths must point to a `.npy` file.

    The input tensor is described with a fixed zero-mean/unit-variance
    preprocessing step using the per-channel `image_means` and `image_stds` of
    `data_config`. The output tensor gets a fixed zero-mean/unit-variance
    postprocessing step whose parameters are chosen so that it inverts that
    normalization (see the derivation in the body).

    Parameters
    ----------
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    data_config : DataConfig
        CAREamics data configuration.
    input_path : Union[Path, str]
        Path to input .npy file.
    output_path : Union[Path, str]
        Path to output .npy file.
    channel_names : list[str] | None, optional
        Channel names, by default None.

    Returns
    -------
    tuple[InputTensorDescr, OutputTensorDescr]
        Input and output tensor descriptions.

    Raises
    ------
    ValueError
        If `image_means` or `image_stds` is None or empty, or if they do not
        have the same length.
    """
    input_axes = _create_axes(input_array, data_config, channel_names)
    output_axes = _create_axes(output_array, data_config, channel_names, False)

    # mean and std
    means = data_config.image_means
    stds = data_config.image_stds

    # explicit checks rather than `assert`, which is stripped under `python -O`
    if not means or not stds:
        raise ValueError(
            "Mean and std cannot be None or empty, got "
            f"image_means={means}, image_stds={stds}."
        )
    if len(means) != len(stds):
        raise ValueError(
            "Mean and std must have the same number of channels, got "
            f"{len(means)} means and {len(stds)} stds."
        )

    # Mean and std required to invert the normalization:
    #   CAREamics denormalization: x = y * (std + eps) + mean
    #   BMZ normalization        : x = (y - mean') / (std' + eps)
    # To apply the BMZ normalization as a denormalization step, we need
    #   std' + eps = 1 / (std + eps)  ->  std'  = 1 / (std + eps) - eps
    #   -mean' / (std' + eps) = mean  ->  mean' = -mean / (std + eps)
    eps = 1e-6
    inv_means = [-mean / (std + eps) for mean, std in zip(means, stds, strict=True)]
    inv_stds = [1 / (std + eps) - eps for std in stds]

    # create input/output descriptions
    input_descr = InputTensorDescr(
        id=TensorId("input"),
        axes=input_axes,
        test_tensor=FileDescr(source=input_path),
        preprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                    mean=means, std=stds, axis="channel"
                )
            )
        ],
    )
    output_descr = OutputTensorDescr(
        id=TensorId("prediction"),
        axes=output_axes,
        test_tensor=FileDescr(source=output_path),
        postprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(  # invert norm
                    mean=inv_means, std=inv_stds, axis="channel"
                )
            )
        ],
    )

    return input_descr, output_descr
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def create_model_description(
    config: Configuration,
    name: str,
    general_description: str,
    data_description: str,
    authors: list[Author],
    inputs: Union[Path, str],
    outputs: Union[Path, str],
    weights_path: Union[Path, str],
    torch_version: str,
    careamics_version: str,
    config_path: Union[Path, str],
    env_path: Union[Path, str],
    covers: list[Union[Path, str]],
    channel_names: list[str] | None = None,
    model_version: str = "0.1.0",
) -> ModelDescr:
    """Assemble the BioImage Model Zoo description of a CAREamics model.

    The description bundles the generated documentation, the input/output
    tensor descriptions (built from the example `.npy` arrays), the PyTorch
    state-dict weights with their architecture, and the CAREamics
    configuration file as an attachment.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    name : str
        Name of the model.
    general_description : str
        General description of the model.
    data_description : str
        Description of the data the model was trained on.
    authors : list[Author]
        Authors of the model.
    inputs : Union[Path, str]
        Path to input .npy file.
    outputs : Union[Path, str]
        Path to output .npy file.
    weights_path : Union[Path, str]
        Path to model weights.
    torch_version : str
        Pytorch version.
    careamics_version : str
        CAREamics version.
    config_path : Union[Path, str]
        Path to model configuration.
    env_path : Union[Path, str]
        Path to environment file.
    covers : list of pathlib.Path or str
        Paths to cover images.
    channel_names : list[str] | None, optional
        Channel names, by default None.
    model_version : str, default "0.1.0"
        Model version.

    Returns
    -------
    ModelDescr
        Model description.
    """
    # generated documentation for the model page
    documentation = readme_factory(
        config,
        careamics_version=careamics_version,
        data_description=data_description,
    )

    # tensor descriptions, derived from the example input/output arrays
    sample_input = np.load(inputs)
    sample_output = np.load(outputs)
    input_descr, output_descr = _create_inputs_ouputs(
        input_array=sample_input,
        output_array=sample_output,
        data_config=config.data_config,
        input_path=inputs,
        output_path=outputs,
        channel_names=channel_names,
    )

    # PyTorch state-dict weights together with the architecture to rebuild them
    model_config = config.algorithm_config.model
    state_dict_descr = PytorchStateDictWeightsDescr(
        source=weights_path,
        architecture=ArchitectureFromLibraryDescr(
            import_from="careamics.models.unet",
            callable=f"{model_config.architecture}",
            kwargs=model_config.model_dump(),
        ),
        pytorch_version=Version(torch_version),
        dependencies=FileDescr(source=Path(env_path)),
    )
    weights_descr = WeightsDescr(pytorch_state_dict=state_dict_descr)

    # tolerances handed to the bioimageio test kwargs for the state-dict weights
    bioimageio_config = {
        "bioimageio": {
            "test_kwargs": {
                "pytorch_state_dict": {
                    "absolute_tolerance": 1e-2,
                    "relative_tolerance": 1e-2,
                }
            }
        }
    }

    project_links = [
        "https://github.com/CAREamics/careamics",
        "https://careamics.github.io/latest/",
    ]

    return ModelDescr(
        name=name,
        authors=authors,
        description=general_description,
        documentation=documentation,
        inputs=[input_descr],
        outputs=[output_descr],
        tags=config.get_algorithm_keywords(),
        links=project_links,
        license="BSD-3-Clause",
        config=bioimageio_config,
        version=model_version,
        weights=weights_descr,
        attachments=[FileDescr(source=config_path)],
        cite=config.get_algorithm_citations(),
        covers=covers,
    )
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
    """Return the paths to the weights and configuration files.

    The model package referenced by `model_desc.root` is extracted first, and
    both returned paths are joined onto the extracted directory. The
    configuration file is the attachment named `careamics.yaml`.

    Parameters
    ----------
    model_desc : ModelDescr
        Model description.

    Returns
    -------
    tuple of (path, path)
        Weights and configuration paths.

    Raises
    ------
    ValueError
        If the model description has no PyTorch state dict weights.
    ValueError
        If no `careamics.yaml` attachment is found.
    """
    if model_desc.weights.pytorch_state_dict is None:
        raise ValueError("No model weights found in model description.")

    # extract the zip model and return the directory
    model_dir = extract(model_desc.root)

    weights_path = model_dir.joinpath(model_desc.weights.pytorch_state_dict.source.path)

    for file in model_desc.attachments:
        # attachment sources are either `Path`s or objects exposing `.path`
        source = file.source
        file_path = source if isinstance(source, Path) else source.path
        if file_path is None:
            continue

        file_path = Path(file_path)
        if file_path.name == "careamics.yaml":
            # use the normalized path: a `Path` source has no `.path` attribute
            config_path = model_dir.joinpath(file_path)
            break
    else:
        raise ValueError("Configuration file not found.")

    return weights_path, config_path
|