careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (133) hide show
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,48 @@
1
+ """Bioimage.io utils."""
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+
6
def get_unzip_path(zip_path: Union[Path, str]) -> Path:
    """Build the path of the folder a bioimage.io archive is unpacked into.

    The folder sits next to the archive and keeps the archive's full name with
    an extra ``.unzip`` suffix appended (``model.zip`` -> ``model.zip.unzip``).

    Parameters
    ----------
    zip_path : Union[Path, str]
        Path to the bioimage.io model archive.

    Returns
    -------
    Path
        Path to the unzipped folder.
    """
    archive = Path(zip_path)
    folder_name = f"{archive.name}.unzip"
    return archive.parent / folder_name
22
+
23
+
24
def create_env_text(
    pytorch_version: str, torchvision_version: Union[str, None] = None
) -> str:
    """Create environment text for the bioimage model.

    The text is a conda ``environment.yml`` that pins Python and PyTorch and
    installs CAREamics from its git repository through pip.

    Parameters
    ----------
    pytorch_version : str
        PyTorch version to pin.
    torchvision_version : Union[str, None], optional
        Torchvision version to pin, by default None (torchvision is listed
        without a version constraint).

    Returns
    -------
    str
        Environment text.
    """
    # torchvision uses its own release numbering (e.g. torch 2.1 pairs with
    # torchvision 0.16), so pinning it to the pytorch version produces an
    # environment that cannot be resolved. Without an explicit version, leave
    # it unpinned so the solver picks the release matching the pinned pytorch.
    if torchvision_version is None:
        torchvision_spec = "torchvision"
    else:
        torchvision_spec = f"torchvision={torchvision_version}"

    env = (
        "name: careamics\n"
        "dependencies:\n"
        "  - python=3.8\n"
        f"  - pytorch={pytorch_version}\n"
        f"  - {torchvision_spec}\n"
        "  - pip\n"
        "  - pip:\n"
        "    - git+https://github.com/CAREamics/careamics.git@dl4mia\n"
    )
    # TODO from pip with package version
    return env
@@ -0,0 +1,318 @@
1
+ """Module use to build BMZ model description."""
2
+ from pathlib import Path
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from bioimageio.spec.model.v0_5 import (
7
+ ArchitectureFromLibraryDescr,
8
+ Author,
9
+ AxisBase,
10
+ AxisId,
11
+ BatchAxis,
12
+ ChannelAxis,
13
+ EnvironmentFileDescr,
14
+ FileDescr,
15
+ FixedZeroMeanUnitVarianceDescr,
16
+ FixedZeroMeanUnitVarianceKwargs,
17
+ Identifier,
18
+ InputTensorDescr,
19
+ ModelDescr,
20
+ OutputTensorDescr,
21
+ PytorchStateDictWeightsDescr,
22
+ SpaceInputAxis,
23
+ SpaceOutputAxis,
24
+ TensorId,
25
+ Version,
26
+ WeightsDescr,
27
+ )
28
+
29
+ from careamics.config import Configuration, DataModel
30
+
31
+ from ._readme_factory import readme_factory
32
+
33
+
34
def _create_axes(
    array: np.ndarray,
    data_config: DataModel,
    channel_names: Optional[List[str]] = None,
    is_input: bool = True,
) -> List[AxisBase]:
    """Build the bioimage.io axes description of an SC(Z)YX array.

    The description always starts with a batch axis, followed by a channel
    axis, followed by one space axis per X/Y/Z letter found in the
    configuration axes (after removing S and C).

    Parameters
    ----------
    array : np.ndarray
        Array whose shape provides the spatial sizes; expected to be SC(Z)YX.
    data_config : DataModel
        CAREamics data configuration, its `axes` string drives the result.
    channel_names : Optional[List[str]], optional
        Channel names, by default None. Required when the configuration axes
        contain "C".
    is_input : bool, optional
        Whether to build input (True) or output (False) space axes, by default
        True.

    Returns
    -------
    List[AxisBase]
        Batch axis, channel axis, then the spatial axes.

    Raises
    ------
    ValueError
        If channel names are not provided when channel axis is present.
    """
    space_axis_cls = SpaceInputAxis if is_input else SpaceOutputAxis

    # keep only the spatial letters: sample and channel are handled separately
    spatial_axes = data_config.axes.replace("S", "").replace("C", "")

    # batch is always present
    axes_model: List[AxisBase] = [BatchAxis()]

    if "C" not in data_config.axes:
        # no channel in the configuration: describe a singleton channel
        names = ["channel"]
    elif channel_names is None:
        raise ValueError(
            f"Channel names must be provided if channel axis is present, axes: "
            f"{data_config.axes}."
        )
    else:
        names = channel_names
    axes_model.append(ChannelAxis(channel_names=[Identifier(n) for n in names]))

    # spatial sizes start at index 2 of the array (after S and C)
    axes_model.extend(
        space_axis_cls(id=AxisId(letter.lower()), size=array.shape[2 + offset])
        for offset, letter in enumerate(spatial_axes)
        if letter in ("X", "Y", "Z")
    )

    return axes_model
98
+
99
+
100
def _create_inputs_ouputs(
    input_array: np.ndarray,
    output_array: np.ndarray,
    data_config: DataModel,
    input_path: Union[Path, str],
    output_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
) -> Tuple[InputTensorDescr, OutputTensorDescr]:
    """Build the input and output tensor descriptions of the model.

    The input is normalized with the configuration mean and std. The output
    gets the same kind of normalization, but with parameters chosen so that it
    inverts the CAREamics normalization. Input and output paths must point to
    a `.npy` file.

    Parameters
    ----------
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    data_config : DataModel
        CAREamics data configuration.
    input_path : Union[Path, str]
        Path to input .npy file.
    output_path : Union[Path, str]
        Path to output .npy file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.

    Returns
    -------
    Tuple[InputTensorDescr, OutputTensorDescr]
        Input and output tensor descriptions.
    """

    def _zero_mean_unit_variance(
        mean_value: float, std_value: float
    ) -> List[FixedZeroMeanUnitVarianceDescr]:
        # single-step normalization pipeline with fixed statistics
        return [
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceKwargs(
                    mean=mean_value, std=std_value
                )
            )
        ]

    input_axes = _create_axes(
        input_array, data_config, channel_names=channel_names, is_input=True
    )
    output_axes = _create_axes(
        output_array, data_config, channel_names=channel_names, is_input=False
    )

    # mean and std
    assert data_config.mean is not None, "Mean cannot be None."
    assert data_config.std is not None, "Std cannot be None."
    mean = data_config.mean
    std = data_config.std

    # CAREamics denormalization: x = y * (std + eps) + mean
    # BMZ normalization        : x = (y - mean') / (std' + eps)
    # Solving for mean' and std' makes the BMZ step act as the denormalization.
    eps = 1e-6
    scale = std + eps
    inv_mean = -mean / scale
    inv_std = 1 / scale - eps

    input_descr = InputTensorDescr(
        id=TensorId("input"),
        axes=input_axes,
        test_tensor=FileDescr(source=input_path),
        preprocessing=_zero_mean_unit_variance(mean, std),
    )
    output_descr = OutputTensorDescr(
        id=TensorId("prediction"),
        axes=output_axes,
        test_tensor=FileDescr(source=output_path),
        postprocessing=_zero_mean_unit_variance(inv_mean, inv_std),
    )

    return input_descr, output_descr
174
+
175
+
176
def create_model_description(
    config: Configuration,
    name: str,
    general_description: str,
    authors: List[Author],
    inputs: Union[Path, str],
    outputs: Union[Path, str],
    weights_path: Union[Path, str],
    torch_version: str,
    careamics_version: str,
    config_path: Union[Path, str],
    env_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
    data_description: Optional[str] = None,
) -> ModelDescr:
    """Assemble the bioimage.io model description of a CAREamics model.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    name : str
        Name of the model.
    general_description : str
        General description of the model.
    authors : List[Author]
        Authors of the model.
    inputs : Union[Path, str]
        Path to input .npy file, loaded to derive the input axes.
    outputs : Union[Path, str]
        Path to output .npy file, loaded to derive the output axes.
    weights_path : Union[Path, str]
        Path to model weights.
    torch_version : str
        Pytorch version.
    careamics_version : str
        CAREamics version, passed to the README generation.
    config_path : Union[Path, str]
        Path to model configuration, added as an attachment.
    env_path : Union[Path, str]
        Path to environment file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Returns
    -------
    ModelDescr
        Model description.
    """
    readme = readme_factory(
        config,
        careamics_version=careamics_version,
        data_description=data_description,
    )

    test_input = np.load(inputs)
    test_output = np.load(outputs)
    input_descr, output_descr = _create_inputs_ouputs(
        input_array=test_input,
        output_array=test_output,
        data_config=config.data_config,
        input_path=inputs,
        output_path=outputs,
        channel_names=channel_names,
    )

    # the architecture is rebuilt by bioimage.io from careamics.models
    model_config = config.algorithm_config.model
    architecture_descr = ArchitectureFromLibraryDescr(
        import_from="careamics.models",
        callable=f"{model_config.architecture}",
        kwargs=model_config.model_dump(),
    )

    state_dict_descr = PytorchStateDictWeightsDescr(
        source=weights_path,
        architecture=architecture_descr,
        pytorch_version=Version(torch_version),
        dependencies=EnvironmentFileDescr(source=env_path),
    )
    weights_descr = WeightsDescr(pytorch_state_dict=state_dict_descr)

    project_links = [
        "https://github.com/CAREamics/careamics",
        "https://careamics.github.io/latest/",
    ]
    # conversion from float32 to float64 creates small differences, so the
    # bioimage.io test compares outputs with relaxed decimals
    test_settings = {
        "bioimageio": {
            "test_kwargs": {
                "pytorch_state_dict": {
                    "decimals": 2,
                }
            }
        }
    }

    return ModelDescr(
        name=name,
        authors=authors,
        description=general_description,
        documentation=readme,
        inputs=[input_descr],
        outputs=[output_descr],
        tags=config.get_algorithm_keywords(),
        links=project_links,
        license="BSD-3-Clause",
        version="0.1.0",
        weights=weights_descr,
        attachments=[FileDescr(source=config_path)],
        cite=config.get_algorithm_citations(),
        config=test_settings,
    )
290
+
291
+
292
def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]:
    """Return the relative path to the weights and configuration files.

    The weights path comes from the PyTorch state dict weights entry. The
    configuration is the single attachment if there is exactly one, otherwise
    the first attachment with a `.yml` suffix.

    Parameters
    ----------
    model_desc : ModelDescr
        Model description.

    Returns
    -------
    Tuple[Path, Path]
        Weights and configuration paths.

    Raises
    ------
    ValueError
        If no configuration file can be found among the attachments.
    """
    weights_path = model_desc.weights.pytorch_state_dict.source.path

    # initialize explicitly: with no attachments, or none ending in ".yml",
    # the variable would otherwise be unbound and raise UnboundLocalError
    # instead of the intended ValueError below
    config_path: Optional[Path] = None

    if len(model_desc.attachments) == 1:
        config_path = model_desc.attachments[0].source.path
    else:
        for file in model_desc.attachments:
            if file.source.path.suffix == ".yml":
                config_path = file.source.path
                break

    if config_path is None:
        raise ValueError("Configuration file not found.")

    return weights_path, config_path
@@ -0,0 +1,231 @@
1
+ """Function to export to the BioImage Model Zoo format."""
2
+ import tempfile
3
+ from pathlib import Path
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import pkg_resources
8
+ from bioimageio.core import load_description, test_model
9
+ from bioimageio.spec import ValidationSummary, save_bioimageio_package
10
+ from torch import __version__, load, save
11
+
12
+ from careamics.config import Configuration, load_configuration, save_configuration
13
+ from careamics.config.support import SupportedArchitecture
14
+ from careamics.lightning_module import CAREamicsKiln
15
+
16
+ from .bioimage import (
17
+ create_env_text,
18
+ create_model_description,
19
+ extract_model_path,
20
+ get_unzip_path,
21
+ )
22
+
23
+
24
def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path:
    """
    Export the model state dictionary to a file.

    Any suffix other than `.pth` (including no suffix) is replaced by `.pth`.

    Parameters
    ----------
    model : CAREamicsKiln
        CAREamics model to export.
    path : Union[Path, str]
        Path to the file where to save the model state dictionary.

    Returns
    -------
    Path
        Path to the saved model state dictionary.
    """
    path = Path(path)

    # make sure it has the correct suffix; compare for equality, since a
    # substring test (`suffix not in ".pth"`) accepts "", ".p" and ".pt"
    # and would leave such paths without the ".pth" suffix
    if path.suffix != ".pth":
        path = path.with_suffix(".pth")

    # save model state dictionary
    # we save through the torch model itself to avoid the initial "model." in the
    # layers naming, which is incompatible with the way the BMZ load torch state dicts
    save(model.model.state_dict(), path)

    return path
52
+
53
+
54
def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None:
    """
    Load a model from a state dictionary.

    The weights are deserialized on CPU and then copied into the parameters of
    `model.model`, which keep their current device.

    Parameters
    ----------
    model : CAREamicsKiln
        CAREamics model to be updated with the weights.
    path : Union[Path, str]
        Path to the model state dictionary.
    """
    path = Path(path)

    # load model state dictionary
    # same as in _export_state_dict, we load through the torch model to be
    # compatible with bioimageio.core expectations for a torch state dict.
    # map_location="cpu" lets weights saved from GPU tensors be loaded on
    # CPU-only machines; load_state_dict copies them onto the model's device.
    # NOTE(review): torch.load unpickles the file; archives from untrusted
    # sources should be considered unsafe (see torch.load `weights_only`).
    state_dict = load(path, map_location="cpu")
    model.model.load_state_dict(state_dict)
72
+
73
+
74
+ # TODO break down in subfunctions
75
def export_to_bmz(
    model: CAREamicsKiln,
    config: Configuration,
    path: Union[Path, str],
    name: str,
    general_description: str,
    authors: List[dict],
    input_array: np.ndarray,
    output_array: np.ndarray,
    channel_names: Optional[List[str]] = None,
    data_description: Optional[str] = None,
) -> None:
    """Export the model to BioImage Model Zoo format.

    Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C.

    The environment file, test arrays, configuration and weights are written to
    a temporary folder, described, tested with `bioimageio.core.test_model`,
    and packaged as a zip archive at `path` (any other suffix, including no
    suffix, is replaced by `.zip`).

    Parameters
    ----------
    model : CAREamicsKiln
        CAREamics model to export.
    config : Configuration
        Model configuration.
    path : Union[Path, str]
        Path to the output file.
    name : str
        Model name.
    general_description : str
        General description of the model.
    authors : List[dict]
        Authors of the model.
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Raises
    ------
    ValueError
        If the model is a Custom model.
    ValueError
        If the bioimage.io test of the model description fails.
    AssertionError
        If the input and output arrays have different shapes.
    """
    path = Path(path)

    # method is not compatible with Custom models
    if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
        raise ValueError(
            "Exporting Custom models to BioImage Model Zoo format is not supported."
        )

    # make sure that input and output arrays have the same shape
    assert input_array.shape == output_array.shape, (
        f"Input ({input_array.shape}) and output ({output_array.shape}) arrays "
        f"have different shapes"
    )

    # make sure it has the correct suffix; compare for equality, since a
    # substring test (`suffix not in ".zip"`) accepts "" and would save an
    # extension-less archive that `load_from_bmz` then rejects
    if path.suffix != ".zip":
        path = path.with_suffix(".zip")

    # versions
    pytorch_version = __version__
    careamics_version = pkg_resources.get_distribution("careamics").version

    # save files in temporary folder
    with tempfile.TemporaryDirectory() as tmpdirname:
        temp_path = Path(tmpdirname)

        # create environment file
        # TODO move in bioimage module
        env_path = temp_path / "environment.yml"
        env_path.write_text(create_env_text(pytorch_version))

        # export inputs and outputs
        inputs = temp_path / "inputs.npy"
        np.save(inputs, input_array)
        outputs = temp_path / "outputs.npy"
        np.save(outputs, output_array)

        # export configuration
        config_path = save_configuration(config, temp_path)

        # export model state dictionary
        weight_path = _export_state_dict(model, temp_path / "weights.pth")

        # create model description
        model_description = create_model_description(
            config=config,
            name=name,
            general_description=general_description,
            authors=authors,
            inputs=inputs,
            outputs=outputs,
            weights_path=weight_path,
            torch_version=pytorch_version,
            careamics_version=careamics_version,
            config_path=config_path,
            env_path=env_path,
            channel_names=channel_names,
            data_description=data_description,
        )

        # test model description
        summary: ValidationSummary = test_model(model_description)
        if summary.status == "failed":
            raise ValueError(f"Model description test failed: {summary}")

        # save bmz model (inside the context: the files above live in temp_path)
        save_bioimageio_package(model_description, output_path=path)
186
+
187
+
188
def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
    """Rebuild a CAREamics model from a BioImage Model Zoo zip archive.

    The archive is opened with `load_description`, the configuration and
    weights are located inside the unzipped folder, the lightning module is
    recreated from the configuration and the weights are loaded into it.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the BioImage Model Zoo archive.

    Returns
    -------
    Tuple[CAREamicsKiln, Configuration]
        CAREamics model and configuration.

    Raises
    ------
    ValueError
        If the path is not a zip file.
    """
    archive = Path(path)

    if archive.suffix != ".zip":
        raise ValueError(f"Path must be a bioimage.io zip file, got {archive}.")

    # load description, this creates an unzipped folder next to the archive
    description = load_description(archive)

    # relative paths inside the archive, made absolute w.r.t. the unzipped folder
    relative_weights, relative_config = extract_model_path(description)
    unzipped_folder = get_unzip_path(archive)
    absolute_weights = unzipped_folder / relative_weights
    absolute_config = unzipped_folder / relative_config

    configuration = load_configuration(absolute_config)
    lightning_model = CAREamicsKiln(algorithm_config=configuration.algorithm_config)
    _load_state_dict(lightning_model, absolute_weights)

    return lightning_model, configuration
@@ -0,0 +1,80 @@
1
+ """Utility functions to load pretrained models."""
2
+
3
+ from pathlib import Path
4
+ from typing import Tuple, Union
5
+
6
+ from torch import load
7
+
8
+ from careamics.config import Configuration
9
+ from careamics.lightning_module import CAREamicsKiln
10
+ from careamics.model_io.bmz_io import load_from_bmz
11
+ from careamics.utils import check_path_exists
12
+
13
+
14
def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
    """
    Load a pretrained model from a checkpoint or a BioImage Model Zoo model.

    Dispatches on the file suffix: `.ckpt` files are read as Lightning
    checkpoints, `.zip` files as BioImage Model Zoo archives.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the pretrained model.

    Returns
    -------
    Tuple[CAREamicsKiln, Configuration]
        Tuple of CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the model format is not supported.
    """
    checked_path = check_path_exists(path)
    suffix = checked_path.suffix

    if suffix == ".ckpt":
        return _load_checkpoint(checked_path)
    if suffix == ".zip":
        return load_from_bmz(checked_path)

    raise ValueError(f"Invalid model format. Expected .ckpt or .zip, got {suffix}.")
+ )
45
+
46
+
47
def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
    """
    Read a Lightning checkpoint and return the model with its configuration.

    The configuration is rebuilt from the `hyper_parameters` entry of the
    checkpoint, and the model through `CAREamicsKiln.load_from_checkpoint`.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the checkpoint.

    Returns
    -------
    Tuple[CAREamicsKiln, Configuration]
        Tuple of CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the checkpoint file does not contain hyper parameters (configuration).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    checkpoint: dict = load(path)

    try:
        hyper_parameters = checkpoint["hyper_parameters"]
    except KeyError as err:
        raise ValueError(
            f"Invalid checkpoint file. No `hyper_parameters` found in the "
            f"checkpoint: {checkpoint.keys()}"
        ) from err

    lightning_model = CAREamicsKiln.load_from_checkpoint(path)
    configuration = Configuration(**hyper_parameters)

    return lightning_model, configuration
@@ -1,4 +1,7 @@
1
1
  """Models package."""
2
2
 
3
- from .model_factory import create_model as create_model
3
+ __all__ = ["model_factory", "UNet"]
4
+
5
+
6
+ from .model_factory import model_factory
4
7
  from .unet import UNet as UNet