careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Bioimage.io utils."""
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_unzip_path(zip_path: Union[Path, str]) -> Path:
    """Build the path of the unzipped folder matching a bioimage.io archive.

    The folder lives in the same directory as the archive and carries the
    archive's full file name with ``.unzip`` appended, e.g. ``model.zip``
    maps to ``model.zip.unzip``.

    Parameters
    ----------
    zip_path : Union[Path, str]
        Path to the bioimage.io model archive.

    Returns
    -------
    Path
        Path to the unzipped folder.
    """
    archive = Path(zip_path)
    folder_name = f"{archive.name}.unzip"

    return archive.parent / folder_name
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_env_text(
    pytorch_version: str, torchvision_version: Union[str, None] = None
) -> str:
    """Create the conda environment file content for the bioimage model.

    Parameters
    ----------
    pytorch_version : str
        Pytorch version, used to pin the ``pytorch`` conda package.
    torchvision_version : Union[str, None], optional
        Torchvision version, used to pin the ``torchvision`` conda package. When
        None (default), torchvision is left unpinned: torchvision follows its own
        version numbering, so re-using the Pytorch version as a pin would request
        a release that does not exist.

    Returns
    -------
    str
        Environment text, in conda ``environment.yml`` format.
    """
    if torchvision_version is None:
        torchvision_spec = "torchvision"
    else:
        torchvision_spec = f"torchvision={torchvision_version}"

    # The pip requirements must be nested one level below the "- pip:" entry,
    # otherwise they are read as regular conda dependencies.
    # TODO from pip with package version
    return (
        "name: careamics\n"
        "dependencies:\n"
        "  - python=3.8\n"
        f"  - pytorch={pytorch_version}\n"
        f"  - {torchvision_spec}\n"
        "  - pip\n"
        "  - pip:\n"
        "    - git+https://github.com/CAREamics/careamics.git@dl4mia\n"
    )
|
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
"""Module used to build the BMZ model description."""
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from bioimageio.spec.model.v0_5 import (
|
|
7
|
+
ArchitectureFromLibraryDescr,
|
|
8
|
+
Author,
|
|
9
|
+
AxisBase,
|
|
10
|
+
AxisId,
|
|
11
|
+
BatchAxis,
|
|
12
|
+
ChannelAxis,
|
|
13
|
+
EnvironmentFileDescr,
|
|
14
|
+
FileDescr,
|
|
15
|
+
FixedZeroMeanUnitVarianceDescr,
|
|
16
|
+
FixedZeroMeanUnitVarianceKwargs,
|
|
17
|
+
Identifier,
|
|
18
|
+
InputTensorDescr,
|
|
19
|
+
ModelDescr,
|
|
20
|
+
OutputTensorDescr,
|
|
21
|
+
PytorchStateDictWeightsDescr,
|
|
22
|
+
SpaceInputAxis,
|
|
23
|
+
SpaceOutputAxis,
|
|
24
|
+
TensorId,
|
|
25
|
+
Version,
|
|
26
|
+
WeightsDescr,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from careamics.config import Configuration, DataConfig
|
|
30
|
+
|
|
31
|
+
from ._readme_factory import readme_factory
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _create_axes(
    array: np.ndarray,
    data_config: DataConfig,
    channel_names: Optional[List[str]] = None,
    is_input: bool = True,
) -> List[AxisBase]:
    """Create axes description.

    Array shape is expected to be SC(Z)YX: the first two dimensions are the
    sample/batch and channel dimensions, the remaining ones are spatial.

    Parameters
    ----------
    array : np.ndarray
        Array whose spatial dimension sizes are recorded in the description.
    data_config : DataConfig
        CAREamics data configuration, its `axes` attribute is used to determine
        which axes are present.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    is_input : bool, optional
        Whether the axes are input axes, by default True.

    Returns
    -------
    List[AxisBase]
        List of axes description.

    Raises
    ------
    ValueError
        If channel names are not provided when channel axis is present.
    """
    # batch is always present
    axes_model = [BatchAxis()]

    if "C" in data_config.axes:
        if channel_names is None:
            raise ValueError(
                f"Channel names must be provided if channel axis is present, axes: "
                f"{data_config.axes}."
            )
        axes_model.append(
            ChannelAxis(channel_names=[Identifier(name) for name in channel_names])
        )
    else:
        # singleton channel
        axes_model.append(ChannelAxis(channel_names=[Identifier("channel")]))

    # Keep only the spatial axes BEFORE enumerating them: any other axis letter
    # present in the configuration (e.g. "T") has no dimension of its own after
    # the first two in an SC(Z)YX array, and counting it would shift the index of
    # every spatial dimension.
    spatial_axes = [axis for axis in data_config.axes if axis in ("Z", "Y", "X")]

    space_axis = SpaceInputAxis if is_input else SpaceOutputAxis
    for ind, axis in enumerate(spatial_axes):
        # spatial dimensions start after the sample and channel dimensions
        axes_model.append(
            space_axis(id=AxisId(axis.lower()), size=array.shape[2 + ind])
        )

    return axes_model
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _create_inputs_ouputs(
    input_array: np.ndarray,
    output_array: np.ndarray,
    data_config: DataConfig,
    input_path: Union[Path, str],
    output_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
) -> Tuple[InputTensorDescr, OutputTensorDescr]:
    """Build the descriptions of the model input and output tensors.

    The input is described with a fixed zero-mean/unit-variance preprocessing
    using the mean and std of the data configuration, and the output with the
    same operation parameterized so that it undoes that normalization.

    Input and output paths must point to a `.npy` file.

    Parameters
    ----------
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    data_config : DataConfig
        CAREamics data configuration, its mean and std must be set.
    input_path : Union[Path, str]
        Path to input .npy file.
    output_path : Union[Path, str]
        Path to output .npy file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.

    Returns
    -------
    Tuple[InputTensorDescr, OutputTensorDescr]
        Input and output tensor descriptions.
    """
    axes_in = _create_axes(input_array, data_config, channel_names)
    axes_out = _create_axes(output_array, data_config, channel_names, False)

    # normalization statistics are mandatory to describe pre/post-processing
    assert data_config.mean is not None, "Mean cannot be None."
    assert data_config.std is not None, "Std cannot be None."
    mean, std = data_config.mean, data_config.std

    # The post-processing has to express the CAREamics denormalization
    #     x = y * (std + eps) + mean
    # with the BMZ normalization operation
    #     x = (y - mean') / (std' + eps)
    # Identifying both expressions gives
    #     mean' = -mean / (std + eps)    and    std' = 1 / (std + eps) - eps
    eps = 1e-6
    denominator = std + eps
    inv_mean = -mean / denominator
    inv_std = 1 / denominator - eps

    # input description: test tensor + normalization
    input_id = TensorId("input")
    input_sample = FileDescr(source=input_path)
    normalize = FixedZeroMeanUnitVarianceDescr(
        kwargs=FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
    )
    input_descr = InputTensorDescr(
        id=input_id,
        axes=axes_in,
        test_tensor=input_sample,
        preprocessing=[normalize],
    )

    # output description: test tensor + inverted normalization
    output_id = TensorId("prediction")
    output_sample = FileDescr(source=output_path)
    denormalize = FixedZeroMeanUnitVarianceDescr(
        kwargs=FixedZeroMeanUnitVarianceKwargs(mean=inv_mean, std=inv_std)
    )
    output_descr = OutputTensorDescr(
        id=output_id,
        axes=axes_out,
        test_tensor=output_sample,
        postprocessing=[denormalize],
    )

    return input_descr, output_descr
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def create_model_description(
    config: Configuration,
    name: str,
    general_description: str,
    authors: List[Author],
    inputs: Union[Path, str],
    outputs: Union[Path, str],
    weights_path: Union[Path, str],
    torch_version: str,
    careamics_version: str,
    config_path: Union[Path, str],
    env_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
    data_description: Optional[str] = None,
) -> ModelDescr:
    """Assemble the bioimage.io description of a CAREamics model.

    The description gathers the generated documentation, the input and output
    tensor descriptions, the Pytorch state dict weights entry and general
    metadata (tags, links, license, citations). The CAREamics configuration file
    is added as an attachment.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    name : str
        Name of the model.
    general_description : str
        General description of the model.
    authors : List[Author]
        Authors of the model.
    inputs : Union[Path, str]
        Path to input .npy file.
    outputs : Union[Path, str]
        Path to output .npy file.
    weights_path : Union[Path, str]
        Path to model weights.
    torch_version : str
        Pytorch version.
    careamics_version : str
        CAREamics version.
    config_path : Union[Path, str]
        Path to model configuration.
    env_path : Union[Path, str]
        Path to environment file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Returns
    -------
    ModelDescr
        Model description.
    """
    # documentation
    readme = readme_factory(
        config,
        careamics_version=careamics_version,
        data_description=data_description,
    )

    # inputs, outputs: the sample arrays are read back from disk
    sample_in = np.load(inputs)
    sample_out = np.load(outputs)
    tensor_in, tensor_out = _create_inputs_ouputs(
        input_array=sample_in,
        output_array=sample_out,
        data_config=config.data_config,
        input_path=inputs,
        output_path=outputs,
        channel_names=channel_names,
    )

    # weights description: architecture entry point, weights and environment
    architecture = ArchitectureFromLibraryDescr(
        import_from="careamics.models",
        callable=f"{config.algorithm_config.model.architecture}",
        kwargs=config.algorithm_config.model.model_dump(),
    )
    state_dict_descr = PytorchStateDictWeightsDescr(
        source=weights_path,
        architecture=architecture,
        pytorch_version=Version(torch_version),
        dependencies=EnvironmentFileDescr(source=env_path),
    )
    weights = WeightsDescr(pytorch_state_dict=state_dict_descr)

    # conversion from float32 to float64 creates small differences, so we relax
    # the constraints on the decimals
    test_kwargs = {"pytorch_state_dict": {"decimals": 2}}

    # overall model description
    return ModelDescr(
        name=name,
        authors=authors,
        description=general_description,
        documentation=readme,
        inputs=[tensor_in],
        outputs=[tensor_out],
        tags=config.get_algorithm_keywords(),
        links=[
            "https://github.com/CAREamics/careamics",
            "https://careamics.github.io/latest/",
        ],
        license="BSD-3-Clause",
        version="0.1.0",
        weights=weights,
        attachments=[FileDescr(source=config_path)],
        cite=config.get_algorithm_citations(),
        config={"bioimageio": {"test_kwargs": test_kwargs}},
    )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]:
    """Return the relative path to the weights and configuration files.

    If the model has a single attachment, it is taken as the configuration file.
    Otherwise, the first attachment with a `.yml` suffix is used.

    Parameters
    ----------
    model_desc : ModelDescr
        Model description.

    Returns
    -------
    Tuple[Path, Path]
        Weights and configuration paths.

    Raises
    ------
    ValueError
        If no configuration file can be found among the attachments.
    """
    weights_path = model_desc.weights.pytorch_state_dict.source.path

    attachments = model_desc.attachments
    if len(attachments) == 1:
        config_path = attachments[0].source.path
    else:
        # Default to None so that a description without attachments, or without
        # any `.yml` attachment, reaches the explicit error below instead of
        # failing on an unbound variable.
        config_path = next(
            (
                file.source.path
                for file in attachments
                if file.source.path.suffix == ".yml"
            ),
            None,
        )

    if config_path is None:
        raise ValueError("Configuration file not found.")

    return weights_path, config_path
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""Function to export to the BioImage Model Zoo format."""
|
|
2
|
+
import tempfile
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import List, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pkg_resources
|
|
8
|
+
from bioimageio.core import load_description, test_model
|
|
9
|
+
from bioimageio.spec import ValidationSummary, save_bioimageio_package
|
|
10
|
+
from torch import __version__, load, save
|
|
11
|
+
|
|
12
|
+
from careamics.config import Configuration, load_configuration, save_configuration
|
|
13
|
+
from careamics.config.support import SupportedArchitecture
|
|
14
|
+
from careamics.lightning_module import CAREamicsModule
|
|
15
|
+
|
|
16
|
+
from .bioimage import (
|
|
17
|
+
create_env_text,
|
|
18
|
+
create_model_description,
|
|
19
|
+
extract_model_path,
|
|
20
|
+
get_unzip_path,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _export_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> Path:
    """
    Export the model state dictionary to a file.

    The file is always written with a `.pth` suffix: any other suffix (or a
    missing one) is replaced.

    Parameters
    ----------
    model : CAREamicsModule
        CAREamics model to export.
    path : Union[Path, str]
        Path to the file where to save the model state dictionary.

    Returns
    -------
    Path
        Path to the saved model state dictionary.
    """
    path = Path(path)

    # make sure it has the correct suffix
    # (strict comparison: a substring test would accept "", ".pt", ".th", ...)
    if path.suffix != ".pth":
        path = path.with_suffix(".pth")

    # save model state dictionary
    # we save through the torch model itself to avoid the initial "model." in the
    # layers naming, which is incompatible with the way the BMZ load torch state dicts
    save(model.model.state_dict(), path)

    return path
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _load_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> None:
    """
    Load a model from a state dictionary.

    The weights are loaded in place into the torch model wrapped by `model`.

    Parameters
    ----------
    model : CAREamicsModule
        CAREamics model to be updated with the weights.
    path : Union[Path, str]
        Path to the model state dictionary.
    """
    path = Path(path)

    # load model state dictionary
    # same as in _export_state_dict, we load through the torch model to be compatible
    # with bioimageio.core expectations for a torch state dict
    # The tensors are mapped to the CPU so that weights saved from an accelerator
    # can be read on a machine without it; `load_state_dict` then copies the values
    # into the parameters of the model, wherever they live.
    # NOTE: `torch.load` relies on pickle, only load files from trusted sources.
    state_dict = load(path, map_location="cpu")
    model.model.load_state_dict(state_dict)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# TODO break down in subfunctions
|
|
75
|
+
def export_to_bmz(
    model: CAREamicsModule,
    config: Configuration,
    path: Union[Path, str],
    name: str,
    general_description: str,
    authors: List[dict],
    input_array: np.ndarray,
    output_array: np.ndarray,
    channel_names: Optional[List[str]] = None,
    data_description: Optional[str] = None,
) -> None:
    """Export the model to BioImage Model Zoo format.

    All files required by the archive (environment file, sample arrays,
    configuration and weights) are written to a temporary folder, the resulting
    model description is tested, then packaged as a zip file.

    Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C.

    Parameters
    ----------
    model : CAREamicsModule
        CAREamics model to export.
    config : Configuration
        Model configuration.
    path : Union[Path, str]
        Path to the output file, a `.zip` suffix is enforced.
    name : str
        Model name.
    general_description : str
        General description of the model.
    authors : List[dict]
        Authors of the model.
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Raises
    ------
    ValueError
        If the model is a Custom model.
    ValueError
        If the test of the model description fails.
    """
    path = Path(path)

    # method is not compatible with Custom models
    if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
        raise ValueError(
            "Exporting Custom models to BioImage Model Zoo format is not supported."
        )

    # make sure that input and output arrays have the same shape
    assert input_array.shape == output_array.shape, (
        f"Input ({input_array.shape}) and output ({output_array.shape}) arrays "
        f"have different shapes"
    )

    # make sure it has the correct suffix
    # (strict comparison: a substring test would accept "", ".z", ".zi", ...)
    if path.suffix != ".zip":
        path = path.with_suffix(".zip")

    # versions
    pytorch_version = __version__
    careamics_version = pkg_resources.get_distribution("careamics").version

    # save files in temporary folder
    with tempfile.TemporaryDirectory() as tmpdirname:
        temp_path = Path(tmpdirname)

        # create environment file
        # TODO move in bioimage module
        env_path = temp_path / "environment.yml"
        env_path.write_text(create_env_text(pytorch_version))

        # export input and outputs
        inputs = temp_path / "inputs.npy"
        np.save(inputs, input_array)
        outputs = temp_path / "outputs.npy"
        np.save(outputs, output_array)

        # export configuration
        config_path = save_configuration(config, temp_path)

        # export model state dictionary
        weight_path = _export_state_dict(model, temp_path / "weights.pth")

        # create model description
        model_description = create_model_description(
            config=config,
            name=name,
            general_description=general_description,
            authors=authors,
            inputs=inputs,
            outputs=outputs,
            weights_path=weight_path,
            torch_version=pytorch_version,
            careamics_version=careamics_version,
            config_path=config_path,
            env_path=env_path,
            channel_names=channel_names,
            data_description=data_description,
        )

        # test model description
        summary: ValidationSummary = test_model(model_description)
        if summary.status == "failed":
            raise ValueError(f"Model description test failed: {summary}")

        # save bmz model
        # (inside the context manager: the description refers to the temporary files)
        save_bioimageio_package(model_description, output_path=path)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
    """Instantiate a CAREamics model from a BioImage Model Zoo archive.

    The configuration stored in the archive is used to create the model, which
    is then updated with the weights of the archive.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the BioImage Model Zoo archive.

    Returns
    -------
    Tuple[CAREamicsModule, Configuration]
        CAREamics model and configuration.

    Raises
    ------
    ValueError
        If the path is not a zip file.
    """
    path = Path(path)
    if path.suffix != ".zip":
        raise ValueError(f"Path must be a bioimage.io zip file, got {path}.")

    # loading the description creates an unzipped folder next to the archive
    description = load_description(path)

    # the description only holds paths relative to the unzipped folder
    relative_weights, relative_config = extract_model_path(description)
    unzip_folder = get_unzip_path(path)
    weights_file = unzip_folder / relative_weights
    config_file = unzip_folder / relative_config

    # configuration first, as it defines the model to instantiate
    config = load_configuration(config_file)
    model = CAREamicsModule(algorithm_config=config.algorithm_config)

    # update the model with the archive weights
    _load_state_dict(model, weights_file)

    return model, config
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Utility functions to load pretrained models."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Tuple, Union
|
|
5
|
+
|
|
6
|
+
from torch import load
|
|
7
|
+
|
|
8
|
+
from careamics.config import Configuration
|
|
9
|
+
from careamics.lightning_module import CAREamicsModule
|
|
10
|
+
from careamics.model_io.bmz_io import load_from_bmz
|
|
11
|
+
from careamics.utils import check_path_exists
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
    """
    Load a pretrained model, dispatching on the file suffix.

    `.ckpt` files are loaded as checkpoints and `.zip` files as BioImage Model
    Zoo archives.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the pretrained model.

    Returns
    -------
    Tuple[CAREamicsModule, Configuration]
        Tuple of CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the model format is not supported.
    """
    path = check_path_exists(path)

    # one loader per supported file format
    loaders = {
        ".ckpt": _load_checkpoint,
        ".zip": load_from_bmz,
    }
    loader = loaders.get(path.suffix)

    if loader is None:
        raise ValueError(
            f"Invalid model format. Expected .ckpt or .zip, got {path.suffix}."
        )

    return loader(path)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
    """
    Load a model from a checkpoint and return both model and configuration.

    The checkpoint is first read to retrieve the configuration stored under its
    `hyper_parameters` entry, then the model itself is restored using
    `CAREamicsModule.load_from_checkpoint`.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the checkpoint.

    Returns
    -------
    Tuple[CAREamicsModule, Configuration]
        Tuple of CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the checkpoint file does not contain hyper parameters (configuration).
    """
    # load checkpoint
    # Only the hyper parameters are read from this copy, so the tensors are mapped
    # to the CPU: this avoids failing on checkpoints saved from an accelerator that
    # is not available on the current machine.
    # NOTE: `torch.load` relies on pickle, only load files from trusted sources.
    checkpoint: dict = load(path, map_location="cpu")

    # attempt to load configuration
    try:
        cfg_dict = checkpoint["hyper_parameters"]
    except KeyError as e:
        raise ValueError(
            f"Invalid checkpoint file. No `hyper_parameters` found in the "
            f"checkpoint: {checkpoint.keys()}"
        ) from e

    model = CAREamicsModule.load_from_checkpoint(path)

    return model, Configuration(**cfg_dict)
|