careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Functions used to create a README.md file for BMZ export."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from careamics.config import Configuration
|
|
9
|
+
from careamics.utils import cwd, get_careamics_home
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _yaml_block(yaml_str: str) -> str:
|
|
13
|
+
"""Return a markdown code block with a yaml string.
|
|
14
|
+
|
|
15
|
+
Parameters
|
|
16
|
+
----------
|
|
17
|
+
yaml_str : str
|
|
18
|
+
YAML string.
|
|
19
|
+
|
|
20
|
+
Returns
|
|
21
|
+
-------
|
|
22
|
+
str
|
|
23
|
+
Markdown code block with the YAML string.
|
|
24
|
+
"""
|
|
25
|
+
return f"```yaml\n{yaml_str}\n```"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def readme_factory(
    config: Configuration,
    careamics_version: str,
    data_description: Optional[str] = None,
) -> Path:
    """Write a markdown README describing a CAREamics model for BMZ export.

    The file is created as `README.md` inside the CAREamics home directory (the
    working directory is switched there for the duration of the call). It contains:

    - the algorithm name and description;
    - the algorithm, data and training parameters, each rendered as a YAML code
      block with `None` fields omitted;
    - the optional `data_description`;
    - the algorithm references, when there are any;
    - links to the CAREamics repository and documentation.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration whose algorithm, data and training sections are
        documented.
    careamics_version : str
        CAREamics version reported in the README.
    data_description : Optional[str], optional
        Free-text description of the training data, inserted in the data section
        when provided, by default None.

    Returns
    -------
    Path
        Absolute path to the written README file.
    """
    algorithm_section = config.algorithm_config
    training_section = config.training_config
    data_section = config.data_config

    def _params_block(section) -> str:
        # render a configuration section as a fenced YAML block without None fields
        return _yaml_block(yaml.dump(section.model_dump(exclude_none=True)))

    # TODO use tempfile as in the bmz_io module
    with cwd(get_careamics_home()):
        readme_file = Path("README.md")
        readme_file.touch()

        flavour = config.get_algorithm_flavour()
        title = flavour + " - CAREamics"

        # header, algorithm description and algorithm parameters
        sections = [
            f"# {title}\n\n",
            "Algorithm description:\n\n",
            config.get_algorithm_description(),
            "\n\n",
            (
                f"{flavour} was trained using CAREamics (version "
                f"{careamics_version}) with the following algorithm "
                f"parameters:\n\n"
            ),
            _params_block(algorithm_section),
            "\n\n",
            "## Data description\n\n",
        ]

        # data section
        if data_description is not None:
            sections += [data_description, "\n\n"]
        sections += [
            "The data was processed using the following parameters:\n\n",
            _params_block(data_section),
            "\n\n",
        ]

        # training section
        sections += [
            "## Training description\n\n",
            "The model was trained using the following parameters:\n\n",
            _params_block(training_section),
            "\n\n",
        ]

        # references are only added when the algorithm provides some
        references = config.get_algorithm_references()
        if references != "":
            sections += ["## References\n\n", references, "\n\n"]

        sections.append(
            "## Links\n\n"
            "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
            "- [CAREamics documentation](https://careamics.github.io/latest/)\n"
        )

        readme_file.write_text("".join(sections))

        # resolved while still inside the CAREamics home directory
        return readme_file.absolute()
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Bioimage.io utils."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_unzip_path(zip_path: Union[Path, str]) -> Path:
    """Return the folder into which a bioimage.io model archive is unzipped.

    The folder sits next to the archive and is named after it, with an extra
    `.unzip` suffix (e.g. `models/model.zip` -> `models/model.zip.unzip`).

    Parameters
    ----------
    zip_path : Union[Path, str]
        Path to the bioimage.io model archive.

    Returns
    -------
    Path
        Path to the unzipped folder.
    """
    archive = Path(zip_path)
    folder_name = f"{archive.name}.unzip"
    return archive.parent.joinpath(folder_name)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def create_env_text(
    pytorch_version: str,
    torchvision_version: Union[str, None] = None,
) -> str:
    """Create environment yaml content for the bioimage model.

    This installs an environment with Python 3.10, the specified pytorch version,
    torchvision, and the latest changes to careamics, installed with pip from the
    GitHub repository.

    Parameters
    ----------
    pytorch_version : str
        Pytorch version, pinned as the `pytorch` dependency.
    torchvision_version : Union[str, None], optional
        Torchvision version. Torchvision is versioned independently of pytorch
        (e.g. pytorch 2.3 pairs with torchvision 0.18), so the pytorch version
        must not be reused for it. If None, torchvision is left unpinned and the
        environment solver picks a version, by default None.

    Returns
    -------
    str
        Environment text.
    """
    # pinning torchvision to the pytorch version produces an unsolvable environment
    if torchvision_version is None:
        torchvision_spec = "torchvision"
    else:
        torchvision_spec = f"torchvision={torchvision_version}"

    # the git requirement must be nested under the `pip:` key to be installed by pip
    env = (
        "name: careamics\n"
        "dependencies:\n"
        "  - python=3.10\n"
        f"  - pytorch={pytorch_version}\n"
        f"  - {torchvision_spec}\n"
        "  - pip\n"
        "  - pip:\n"
        "      - git+https://github.com/CAREamics/careamics.git\n"
    )

    return env
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
"""Module use to build BMZ model description."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import List, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from bioimageio.spec.model.v0_5 import (
|
|
8
|
+
ArchitectureFromLibraryDescr,
|
|
9
|
+
Author,
|
|
10
|
+
AxisBase,
|
|
11
|
+
AxisId,
|
|
12
|
+
BatchAxis,
|
|
13
|
+
ChannelAxis,
|
|
14
|
+
EnvironmentFileDescr,
|
|
15
|
+
FileDescr,
|
|
16
|
+
FixedZeroMeanUnitVarianceAlongAxisKwargs,
|
|
17
|
+
FixedZeroMeanUnitVarianceDescr,
|
|
18
|
+
Identifier,
|
|
19
|
+
InputTensorDescr,
|
|
20
|
+
ModelDescr,
|
|
21
|
+
OutputTensorDescr,
|
|
22
|
+
PytorchStateDictWeightsDescr,
|
|
23
|
+
SpaceInputAxis,
|
|
24
|
+
SpaceOutputAxis,
|
|
25
|
+
TensorId,
|
|
26
|
+
Version,
|
|
27
|
+
WeightsDescr,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
from careamics.config import Configuration, DataConfig
|
|
31
|
+
|
|
32
|
+
from ._readme_factory import readme_factory
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _create_axes(
    array: np.ndarray,
    data_config: DataConfig,
    channel_names: Optional[List[str]] = None,
    is_input: bool = True,
) -> List[AxisBase]:
    """Create axes description.

    Array shape is expected to be SC(Z)YX. The spatial axis sizes are read from
    `array.shape[2:]` in that canonical Z, Y, X order, independently of the order
    in which `data_config.axes` lists them. Non-spatial axes other than S and C
    (e.g. T) therefore do not shift the spatial indices.

    Parameters
    ----------
    array : np.ndarray
        Array, with shape SC(Z)YX.
    data_config : DataConfig
        CAREamics data configuration.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    is_input : bool, optional
        Whether the axes are input axes, by default True.

    Returns
    -------
    List[AxisBase]
        List of axes description: batch, channel, then the spatial axes.

    Raises
    ------
    ValueError
        If channel names are not provided when channel axis is present.
    """
    # spatial dimensions of the SC(Z)YX array, in the array's own order
    spatial_axes = [axis for axis in "ZYX" if axis in data_config.axes]

    # batch is always present
    axes_model: List[AxisBase] = [BatchAxis()]

    if "C" in data_config.axes:
        if channel_names is None:
            raise ValueError(
                f"Channel names must be provided if channel axis is present, axes: "
                f"{data_config.axes}."
            )
        axes_model.append(
            ChannelAxis(channel_names=[Identifier(name) for name in channel_names])
        )
    else:
        # singleton channel
        axes_model.append(ChannelAxis(channel_names=[Identifier("channel")]))

    # spatial axes, sizes taken after the S and C dimensions
    space_axis_cls = SpaceInputAxis if is_input else SpaceOutputAxis
    for ind, axis in enumerate(spatial_axes):
        axes_model.append(
            space_axis_cls(id=AxisId(axis.lower()), size=array.shape[2 + ind])
        )

    return axes_model
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _create_inputs_ouputs(
    input_array: np.ndarray,
    output_array: np.ndarray,
    data_config: DataConfig,
    input_path: Union[Path, str],
    output_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
) -> Tuple[InputTensorDescr, OutputTensorDescr]:
    """Create input and output tensor description.

    Input and output paths must point to a `.npy` file.

    - The input description applies the CAREamics normalization (image means and
      stds) as a preprocessing step.
    - The output description applies the inverse transform as a postprocessing
      step.

    Parameters
    ----------
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    data_config : DataConfig
        CAREamics data configuration.
    input_path : Union[Path, str]
        Path to input .npy file.
    output_path : Union[Path, str]
        Path to output .npy file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.

    Returns
    -------
    Tuple[InputTensorDescr, OutputTensorDescr]
        Input and output tensor descriptions.

    Raises
    ------
    ValueError
        If the image means or stds are missing or empty, or if they do not have
        the same number of values.
    """
    input_axes = _create_axes(input_array, data_config, channel_names)
    output_axes = _create_axes(output_array, data_config, channel_names, False)

    # mean and std: validated explicitly (asserts are stripped under `python -O`,
    # and the std must be checked on its own rather than through the means)
    means = data_config.image_means
    stds = data_config.image_stds
    if not means or not stds:
        raise ValueError("Mean and std cannot be None.")
    if len(means) != len(stds):
        raise ValueError(
            f"Mean and std must have the same number of values, got "
            f"{len(means)} means and {len(stds)} stds."
        )

    # Mean and std that make the BMZ normalization act as the CAREamics
    # denormalization:
    #   CAREamics denormalization: x = y * (std + eps) + mean
    #   BMZ normalization        : x = (y - mean') / (std' + eps)
    # Identifying both expressions gives:
    #   std'  = 1 / (std + eps) - eps
    #   mean' = -mean / (std + eps)
    eps = 1e-6
    inv_means = [-mean / (std + eps) for mean, std in zip(means, stds)]
    inv_stds = [1 / (std + eps) - eps for std in stds]

    # create input/output descriptions
    input_descr = InputTensorDescr(
        id=TensorId("input"),
        axes=input_axes,
        test_tensor=FileDescr(source=input_path),
        preprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                    mean=means, std=stds, axis="channel"
                )
            )
        ],
    )
    output_descr = OutputTensorDescr(
        id=TensorId("prediction"),
        axes=output_axes,
        test_tensor=FileDescr(source=output_path),
        postprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(  # invert norm
                    mean=inv_means, std=inv_stds, axis="channel"
                )
            )
        ],
    )

    return input_descr, output_descr
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def create_model_description(
    config: Configuration,
    name: str,
    general_description: str,
    authors: List[Author],
    inputs: Union[Path, str],
    outputs: Union[Path, str],
    weights_path: Union[Path, str],
    torch_version: str,
    careamics_version: str,
    config_path: Union[Path, str],
    env_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
    data_description: Optional[str] = None,
) -> ModelDescr:
    """Assemble the bioimage.io model description of a CAREamics model.

    The description combines the following, with the configuration file attached:

    - a generated README;
    - input/output tensor descriptions built from the test `.npy` files;
    - the pytorch state-dict weights, with their architecture and environment.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    name : str
        Name of the model.
    general_description : str
        General description of the model.
    authors : List[Author]
        Authors of the model.
    inputs : Union[Path, str]
        Path to input .npy file.
    outputs : Union[Path, str]
        Path to output .npy file.
    weights_path : Union[Path, str]
        Path to model weights.
    torch_version : str
        Pytorch version.
    careamics_version : str
        CAREamics version.
    config_path : Union[Path, str]
        Path to model configuration.
    env_path : Union[Path, str]
        Path to environment file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Returns
    -------
    ModelDescr
        Model description.
    """
    # documentation, written to disk by the README factory
    readme_path = readme_factory(
        config,
        careamics_version=careamics_version,
        data_description=data_description,
    )

    # test tensors are loaded to size the axes of the tensor descriptions
    test_input = np.load(inputs)
    test_output = np.load(outputs)
    input_descr, output_descr = _create_inputs_ouputs(
        input_array=test_input,
        output_array=test_output,
        data_config=config.data_config,
        input_path=inputs,
        output_path=outputs,
        channel_names=channel_names,
    )

    # weights, with the architecture rebuilt from the model configuration
    model_config = config.algorithm_config.model
    weights_descr = WeightsDescr(
        pytorch_state_dict=PytorchStateDictWeightsDescr(
            source=weights_path,
            architecture=ArchitectureFromLibraryDescr(
                import_from="careamics.models.unet",
                callable=f"{model_config.architecture}",
                kwargs=model_config.model_dump(),
            ),
            pytorch_version=Version(torch_version),
            dependencies=EnvironmentFileDescr(source=env_path),
        ),
    )

    # conversion from float32 to float64 creates small differences, so the
    # number of decimals compared in the bioimageio test is relaxed
    test_settings = {
        "bioimageio": {
            "test_kwargs": {
                "pytorch_state_dict": {
                    "decimals": 0,
                }
            }
        }
    }

    return ModelDescr(
        name=name,
        authors=authors,
        description=general_description,
        documentation=readme_path,
        inputs=[input_descr],
        outputs=[output_descr],
        tags=config.get_algorithm_keywords(),
        links=[
            "https://github.com/CAREamics/careamics",
            "https://careamics.github.io/latest/",
        ],
        license="BSD-3-Clause",
        version="0.1.0",
        weights=weights_descr,
        attachments=[FileDescr(source=config_path)],
        cite=config.get_algorithm_citations(),
        config=test_settings,
    )
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]:
    """Return the relative path to the weights and configuration files.

    The weights path is taken from the pytorch state-dict weights entry. How the
    configuration is found depends on the number of attachments:

    - a single attachment is returned as the configuration as-is;
    - otherwise, the first attachment with a `.yml` or `.yaml` suffix is used.

    Parameters
    ----------
    model_desc : ModelDescr
        Model description.

    Returns
    -------
    Tuple[Path, Path]
        Weights and configuration paths.

    Raises
    ------
    ValueError
        If no configuration file can be found among the attachments.
    """
    weights_path = model_desc.weights.pytorch_state_dict.source.path

    # initialised so that "not found" reaches the explicit error below instead of
    # raising UnboundLocalError (e.g. with no attachments)
    config_path: Optional[Path] = None

    if len(model_desc.attachments) == 1:
        config_path = model_desc.attachments[0].source.path
    else:
        for file in model_desc.attachments:
            if file.source.path.suffix in (".yml", ".yaml"):
                config_path = file.source.path
                break

    if config_path is None:
        raise ValueError("Configuration file not found.")

    return weights_path, config_path
|