careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,171 @@
1
+ """Convenience function to create covers for the BMZ."""
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+ from PIL import Image
8
+
9
# RGB colors used by `_four_channel_image` to tint the first four channels of a
# multi-channel cover (one row per channel, in channel order).
color_palette = np.array(
    [
        np.array([255, 195, 0]),  # amber / yellow (#FFC300)
        np.array([189, 226, 240]),  # light blue (#BDE2F0)
        np.array([96, 60, 76]),  # dark plum (#603C4C)
        np.array([193, 225, 193]),  # pale green (#C1E1C1)
    ]
)
17
+
18
+
19
def _get_norm_slice(array: NDArray) -> NDArray:
    """Get the normalized middle slice of a 4D or 5D array (SC(Z)YX).

    Only the first sample (index 0 along S) is used. For 5D arrays the middle
    Z plane is taken. Single-channel slices are returned as YX, multi-channel
    slices as YXC. Intensities are min-max rescaled to [0, 255].

    Parameters
    ----------
    array : NDArray
        Array from which to get the middle slice, SCYX or SCZYX.

    Returns
    -------
    NDArray
        Normalized middle slice of the input array, with dtype uint8. A slice
        of constant intensity yields an all-zero image.

    Raises
    ------
    ValueError
        If the array is not 4D or 5D.
    """
    if array.ndim not in (4, 5):
        raise ValueError("Array must be 4D or 5D.")

    channels = array.shape[1] > 1
    z_stack = array.ndim == 5

    # get slice: first sample, and the middle plane along Z if present
    if z_stack:
        array_slice = array[0, :, array.shape[2] // 2, ...]
    else:
        array_slice = array[0, ...]

    # channels: move C last (YXC), or drop the singleton channel axis (YX)
    if channels:
        array_slice = np.moveaxis(array_slice, 0, -1)
    else:
        array_slice = array_slice[0, ...]

    # Integer inputs (e.g. raw uint8/uint16 images) keep their integer dtype
    # in `255 * (x - min)` and silently wrap around; compute in float instead.
    if not np.issubdtype(array_slice.dtype, np.floating):
        array_slice = array_slice.astype(np.float64)

    min_value = array_slice.min()
    value_range = array_slice.max() - min_value

    # A constant slice has zero range: dividing would yield NaNs, whose cast
    # to uint8 is undefined, so return a black image instead.
    if value_range == 0:
        return np.zeros(array_slice.shape, dtype=np.uint8)

    # normalize to [0, 255]
    array_slice = 255 * (array_slice - min_value) / value_range

    return array_slice.astype(np.uint8)
58
+
59
+
60
def _four_channel_image(array: NDArray) -> Image.Image:
    """Convert a 4-channel array to an RGB Image by blending a color palette.

    Each of the four channels is tinted with the corresponding row of
    `color_palette`, and the tinted channels are summed.

    Parameters
    ----------
    array : NDArray
        Normalized YXC array to convert, with intensities in [0, 255]. It is
        expected to have four channels (only the first four are used).

    Returns
    -------
    PIL.Image.Image
        RGB image.
    """
    colors = color_palette[np.newaxis, np.newaxis, :, :]

    # Scale intensities to [0, 1] so that a fully-on channel contributes exactly
    # its palette color; raw 0-255 values times 0-255 colors reach ~65k and
    # would wrap around on the uint8 cast.
    intensities = array[..., :4, np.newaxis] / 255.0
    blended = np.sum(intensities * colors, axis=-2)

    # Overlapping channels can add up past 255: clip instead of wrapping.
    four_c_array = np.clip(blended, 0, 255).astype(np.uint8)

    return Image.fromarray(four_c_array).convert("RGB")
77
+
78
+
79
def _convert_to_image(
    original_shape: tuple[int, ...], array: NDArray
) -> Image.Image:
    """Convert a normalized slice to an RGB Image.

    The conversion depends on the number of channels of the original array
    (axis 1 of its SC(Z)YX shape):

    - 1 channel: the YX slice is rendered as grayscale.
    - 2 channels: an all-zero channel is prepended, giving a 3-channel image.
    - 3 channels: the YXC slice is used directly as RGB.
    - 4 or more channels: the first four channels are blended with the color
      palette via `_four_channel_image`.

    Parameters
    ----------
    original_shape : tuple of int
        Shape of the array the slice was extracted from (SC(Z)YX).
    array : NDArray
        Normalized uint8 slice to convert, YX or YXC.

    Returns
    -------
    PIL.Image.Image
        RGB image.
    """
    n_channels = original_shape[1]

    if n_channels <= 1:
        return Image.fromarray(array).convert("L").convert("RGB")

    if n_channels == 2:
        # add an empty channel in front so the array can be read as RGB
        array = np.concatenate([np.zeros_like(array[..., 0:1]), array], axis=-1)
        return Image.fromarray(array).convert("RGB")

    if n_channels == 3:
        return Image.fromarray(array).convert("RGB")

    # four or more channels
    return _four_channel_image(array[..., :4])
108
+
109
+
110
def create_cover(directory: Path, array_in: NDArray, array_out: NDArray) -> Path:
    """Create a cover image from input and output arrays.

    Input and output arrays are expected to be SC(Z)YX. Only the first sample
    is used and, for images with a Z dimension, the middle slice is taken.
    The cover shows part of the input and part of the output. It uses a
    left/right split when both slices have the same width, and otherwise a
    top/bottom split, which requires the same height.

    Parameters
    ----------
    directory : Path
        Directory in which to save the cover. A string path is also accepted.
    array_in : numpy.ndarray
        Input array from which to create the cover image.
    array_out : numpy.ndarray
        Output array from which to create the cover image.

    Returns
    -------
    Path
        Path to the saved cover image (`cover.png` in `directory`).

    Raises
    ------
    ValueError
        If the input and output slices share neither their width nor height.
    """
    # extract slice and normalize arrays
    slice_in = _get_norm_slice(array_in)
    slice_out = _get_norm_slice(array_out)

    # Slices are YX or YXC, so the spatial sizes are always the first two axes.
    # Comparing shape[-1] / shape[-2] would compare channel counts for YXC.
    height_in, width_in = slice_in.shape[:2]
    height_out, width_out = slice_out.shape[:2]

    horizontal_split = width_in == width_out
    if not horizontal_split and height_in != height_out:
        raise ValueError("Input and output arrays have different shapes.")

    # convert to Image
    image_in = _convert_to_image(array_in.shape, slice_in)
    image_out = _convert_to_image(array_out.shape, slice_out)

    cover = Image.new("RGB", (image_in.width, image_in.height))

    # split horizontally or vertically; the output part is cropped at the same
    # coordinates it is pasted to, so odd sizes leave no empty row/column
    if horizontal_split:
        width = image_in.width // 2
        cover.paste(image_in.crop((0, 0, width, image_in.height)), (0, 0))
        cover.paste(
            image_out.crop((width, 0, image_in.width, image_in.height)),
            (width, 0),
        )
    else:
        height = image_in.height // 2
        cover.paste(image_in.crop((0, 0, image_in.width, height)), (0, 0))
        cover.paste(
            image_out.crop((0, height, image_in.width, image_in.height)),
            (0, height),
        )

    # save
    cover_path = Path(directory) / "cover.png"
    cover.save(cover_path)

    return cover_path
@@ -1,9 +1,10 @@
1
1
  """Module use to build BMZ model description."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import List, Optional, Tuple, Union
4
+ from typing import Optional, Union
5
5
 
6
6
  import numpy as np
7
+ from bioimageio.spec._internal.io import resolve_and_extract
7
8
  from bioimageio.spec.model.v0_5 import (
8
9
  ArchitectureFromLibraryDescr,
9
10
  Author,
@@ -27,17 +28,17 @@ from bioimageio.spec.model.v0_5 import (
27
28
  WeightsDescr,
28
29
  )
29
30
 
30
- from careamics.config import Configuration, DataConfig
31
+ from careamics.config import Configuration, GeneralDataConfig
31
32
 
32
33
  from ._readme_factory import readme_factory
33
34
 
34
35
 
35
36
  def _create_axes(
36
37
  array: np.ndarray,
37
- data_config: DataConfig,
38
- channel_names: Optional[List[str]] = None,
38
+ data_config: GeneralDataConfig,
39
+ channel_names: Optional[list[str]] = None,
39
40
  is_input: bool = True,
40
- ) -> List[AxisBase]:
41
+ ) -> list[AxisBase]:
41
42
  """Create axes description.
42
43
 
43
44
  Array shape is expected to be SC(Z)YX.
@@ -48,15 +49,15 @@ def _create_axes(
48
49
  Array.
49
50
  data_config : DataModel
50
51
  CAREamics data configuration.
51
- channel_names : Optional[List[str]], optional
52
+ channel_names : Optional[list[str]], optional
52
53
  Channel names, by default None.
53
54
  is_input : bool, optional
54
55
  Whether the axes are input axes, by default True.
55
56
 
56
57
  Returns
57
58
  -------
58
- List[AxisBase]
59
- List of axes description.
59
+ list[AxisBase]
60
+ list of axes description.
60
61
 
61
62
  Raises
62
63
  ------
@@ -101,11 +102,11 @@ def _create_axes(
101
102
  def _create_inputs_ouputs(
102
103
  input_array: np.ndarray,
103
104
  output_array: np.ndarray,
104
- data_config: DataConfig,
105
+ data_config: GeneralDataConfig,
105
106
  input_path: Union[Path, str],
106
107
  output_path: Union[Path, str],
107
- channel_names: Optional[List[str]] = None,
108
- ) -> Tuple[InputTensorDescr, OutputTensorDescr]:
108
+ channel_names: Optional[list[str]] = None,
109
+ ) -> tuple[InputTensorDescr, OutputTensorDescr]:
109
110
  """Create input and output tensor description.
110
111
 
111
112
  Input and output paths must point to a `.npy` file.
@@ -122,12 +123,12 @@ def _create_inputs_ouputs(
122
123
  Path to input .npy file.
123
124
  output_path : Union[Path, str]
124
125
  Path to output .npy file.
125
- channel_names : Optional[List[str]], optional
126
+ channel_names : Optional[list[str]], optional
126
127
  Channel names, by default None.
127
128
 
128
129
  Returns
129
130
  -------
130
- Tuple[InputTensorDescr, OutputTensorDescr]
131
+ tuple[InputTensorDescr, OutputTensorDescr]
131
132
  Input and output tensor descriptions.
132
133
  """
133
134
  input_axes = _create_axes(input_array, data_config, channel_names)
@@ -186,7 +187,8 @@ def create_model_description(
186
187
  config: Configuration,
187
188
  name: str,
188
189
  general_description: str,
189
- authors: List[Author],
190
+ data_description: str,
191
+ authors: list[Author],
190
192
  inputs: Union[Path, str],
191
193
  outputs: Union[Path, str],
192
194
  weights_path: Union[Path, str],
@@ -194,8 +196,9 @@ def create_model_description(
194
196
  careamics_version: str,
195
197
  config_path: Union[Path, str],
196
198
  env_path: Union[Path, str],
197
- channel_names: Optional[List[str]] = None,
198
- data_description: Optional[str] = None,
199
+ covers: list[Union[Path, str]],
200
+ channel_names: Optional[list[str]] = None,
201
+ model_version: str = "0.1.0",
199
202
  ) -> ModelDescr:
200
203
  """Create model description.
201
204
 
@@ -207,7 +210,9 @@ def create_model_description(
207
210
  Name of the model.
208
211
  general_description : str
209
212
  General description of the model.
210
- authors : List[Author]
213
+ data_description : str
214
+ Description of the data the model was trained on.
215
+ authors : list[Author]
211
216
  Authors of the model.
212
217
  inputs : Union[Path, str]
213
218
  Path to input .npy file.
@@ -223,10 +228,12 @@ def create_model_description(
223
228
  Path to model configuration.
224
229
  env_path : Union[Path, str]
225
230
  Path to environment file.
226
- channel_names : Optional[List[str]], optional
231
+ covers : list of pathlib.Path or str
232
+ Paths to cover images.
233
+ channel_names : Optional[list[str]], optional
227
234
  Channel names, by default None.
228
- data_description : Optional[str], optional
229
- Description of the data, by default None.
235
+ model_version : str, default "0.1.0"
236
+ Model version.
230
237
 
231
238
  Returns
232
239
  -------
@@ -280,16 +287,27 @@ def create_model_description(
280
287
  "https://careamics.github.io/latest/",
281
288
  ],
282
289
  license="BSD-3-Clause",
283
- version="0.1.0",
290
+ config={
291
+ "bioimageio": {
292
+ "test_kwargs": {
293
+ "pytorch_state_dict": {
294
+ "absolute_tolerance": 1e-2,
295
+ "relative_tolerance": 1e-2,
296
+ }
297
+ }
298
+ }
299
+ },
300
+ version=model_version,
284
301
  weights=weights_descr,
285
302
  attachments=[FileDescr(source=config_path)],
286
303
  cite=config.get_algorithm_citations(),
304
+ covers=covers,
287
305
  )
288
306
 
289
307
  return model
290
308
 
291
309
 
292
- def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]:
310
+ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
293
311
  """Return the relative path to the weights and configuration files.
294
312
 
295
313
  Parameters
@@ -299,20 +317,24 @@ def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]:
299
317
 
300
318
  Returns
301
319
  -------
302
- Tuple[Path, Path]
320
+ tuple of (path, path)
303
321
  Weights and configuration paths.
304
322
  """
305
- weights_path = model_desc.weights.pytorch_state_dict.source.path
306
-
307
- if len(model_desc.attachments) == 1:
308
- config_path = model_desc.attachments[0].source.path
323
+ if model_desc.weights.pytorch_state_dict is None:
324
+ raise ValueError("No model weights found in model description.")
325
+ weights_path = resolve_and_extract(
326
+ model_desc.weights.pytorch_state_dict.source
327
+ ).path
328
+
329
+ for file in model_desc.attachments:
330
+ file_path = file.source if isinstance(file.source, Path) else file.source.path
331
+ if file_path is None:
332
+ continue
333
+ file_path = Path(file_path)
334
+ if file_path.name == "careamics.yaml":
335
+ config_path = resolve_and_extract(file.source).path
336
+ break
309
337
  else:
310
- for file in model_desc.attachments:
311
- if file.source.path.suffix == ".yml":
312
- config_path = file.source.path
313
- break
314
-
315
- if config_path is None:
316
- raise ValueError("Configuration file not found.")
338
+ raise ValueError("Configuration file not found.")
317
339
 
318
340
  return weights_path, config_path
@@ -2,12 +2,13 @@
2
2
 
3
3
  import tempfile
4
4
  from pathlib import Path
5
- from typing import List, Optional, Tuple, Union
5
+ from typing import Optional, Union
6
6
 
7
7
  import numpy as np
8
8
  import pkg_resources
9
- from bioimageio.core import load_description, test_model
9
+ from bioimageio.core import load_model_description, test_model
10
10
  from bioimageio.spec import ValidationSummary, save_bioimageio_package
11
+ from pydantic import HttpUrl
11
12
  from torch import __version__ as PYTORCH_VERSION
12
13
  from torch import load, save
13
14
  from torchvision import __version__ as TORCHVISION_VERSION
@@ -20,8 +21,8 @@ from .bioimage import (
20
21
  create_env_text,
21
22
  create_model_description,
22
23
  extract_model_path,
23
- get_unzip_path,
24
24
  )
25
+ from .bioimage.cover_factory import create_cover
25
26
 
26
27
 
27
28
  def _export_state_dict(
@@ -85,11 +86,13 @@ def export_to_bmz(
85
86
  path_to_archive: Union[Path, str],
86
87
  model_name: str,
87
88
  general_description: str,
88
- authors: List[dict],
89
+ data_description: str,
90
+ authors: list[dict],
89
91
  input_array: np.ndarray,
90
92
  output_array: np.ndarray,
91
- channel_names: Optional[List[str]] = None,
92
- data_description: Optional[str] = None,
93
+ covers: Optional[list[Union[Path, str]]] = None,
94
+ channel_names: Optional[list[str]] = None,
95
+ model_version: str = "0.1.0",
93
96
  ) -> None:
94
97
  """Export the model to BioImage Model Zoo format.
95
98
 
@@ -110,30 +113,23 @@ def export_to_bmz(
110
113
  Model name.
111
114
  general_description : str
112
115
  General description of the model.
113
- authors : List[dict]
116
+ data_description : str
117
+ Description of the data the model was trained on.
118
+ authors : list[dict]
114
119
  Authors of the model.
115
120
  input_array : np.ndarray
116
121
  Input array, should not have been normalized.
117
122
  output_array : np.ndarray
118
123
  Output array, should have been denormalized.
119
- channel_names : Optional[List[str]], optional
124
+ covers : list of pathlib.Path or str, default=None
125
+ Paths to the cover images.
126
+ channel_names : Optional[list[str]], optional
120
127
  Channel names, by default None.
121
- data_description : Optional[str], optional
122
- Description of the data, by default None.
123
-
124
- Raises
125
- ------
126
- ValueError
127
- If the model is a Custom model.
128
+ model_version : str, default="0.1.0"
129
+ Model version.
128
130
  """
129
131
  path_to_archive = Path(path_to_archive)
130
132
 
131
- # method is not compatible with Custom models
132
- if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM:
133
- raise ValueError(
134
- "Exporting Custom models to BioImage Model Zoo format is not supported."
135
- )
136
-
137
133
  if path_to_archive.suffix != ".zip":
138
134
  raise ValueError(
139
135
  f"Path to archive must point to a zip file, got {path_to_archive}."
@@ -161,16 +157,21 @@ def export_to_bmz(
161
157
  np.save(outputs, output_array)
162
158
 
163
159
  # export configuration
164
- config_path = save_configuration(config, temp_path)
160
+ config_path = save_configuration(config, temp_path / "careamics.yaml")
165
161
 
166
162
  # export model state dictionary
167
163
  weight_path = _export_state_dict(model, temp_path / "weights.pth")
168
164
 
165
+ # export cover if necesary
166
+ if covers is None:
167
+ covers = [create_cover(temp_path, input_array, output_array)]
168
+
169
169
  # create model description
170
170
  model_description = create_model_description(
171
171
  config=config,
172
172
  name=model_name,
173
173
  general_description=general_description,
174
+ data_description=data_description,
174
175
  authors=authors,
175
176
  inputs=inputs,
176
177
  outputs=outputs,
@@ -179,12 +180,18 @@ def export_to_bmz(
179
180
  careamics_version=careamics_version,
180
181
  config_path=config_path,
181
182
  env_path=env_path,
183
+ covers=covers,
182
184
  channel_names=channel_names,
183
- data_description=data_description,
185
+ model_version=model_version,
184
186
  )
185
187
 
186
188
  # test model description
187
- summary: ValidationSummary = test_model(model_description)
189
+ test_kwargs = (
190
+ model_description.config.get("bioimageio", {})
191
+ .get("test_kwargs", {})
192
+ .get("pytorch_state_dict", {})
193
+ )
194
+ summary: ValidationSummary = test_model(model_description, **test_kwargs)
188
195
  if summary.status == "failed":
189
196
  raise ValueError(f"Model description test failed: {summary}")
190
197
 
@@ -193,41 +200,34 @@ def export_to_bmz(
193
200
 
194
201
 
195
202
  def load_from_bmz(
196
- path: Union[Path, str]
197
- ) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
203
+ path: Union[Path, str, HttpUrl]
204
+ ) -> tuple[Union[FCNModule, VAEModule], Configuration]:
198
205
  """Load a model from a BioImage Model Zoo archive.
199
206
 
200
207
  Parameters
201
208
  ----------
202
- path : Union[Path, str]
203
- Path to the BioImage Model Zoo archive.
209
+ path : Path, str or HttpUrl
210
+ Path to the BioImage Model Zoo archive. A Http URL must point to a downloadable
211
+ location.
204
212
 
205
213
  Returns
206
214
  -------
207
- Tuple[CAREamicsKiln, Configuration]
208
- CAREamics model and configuration.
215
+ FCNModel or VAEModel
216
+ The loaded CAREamics model.
217
+ Configuration
218
+ The loaded CAREamics configuration.
209
219
 
210
220
  Raises
211
221
  ------
212
222
  ValueError
213
223
  If the path is not a zip file.
214
224
  """
215
- path = Path(path)
216
-
217
- if path.suffix != ".zip":
218
- raise ValueError(f"Path must be a bioimage.io zip file, got {path}.")
219
-
220
225
  # load description, this creates an unzipped folder next to the archive
221
- model_desc = load_description(path)
226
+ model_desc = load_model_description(path)
222
227
 
223
- # extract relative paths
228
+ # extract paths
224
229
  weights_path, config_path = extract_model_path(model_desc)
225
230
 
226
- # create folder path and absolute paths
227
- unzip_path = get_unzip_path(path)
228
- weights_path = unzip_path / weights_path
229
- config_path = unzip_path / config_path
230
-
231
231
  # load configuration
232
232
  config = load_configuration(config_path)
233
233
 
@@ -1,11 +1,11 @@
1
1
  """Utility functions to load pretrained models."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Tuple, Union
4
+ from typing import Union
5
5
 
6
6
  import torch
7
7
 
8
- from careamics.config import Configuration
8
+ from careamics.config import Configuration, configuration_factory
9
9
  from careamics.lightning.lightning_module import FCNModule, VAEModule
10
10
  from careamics.model_io.bmz_io import load_from_bmz
11
11
  from careamics.utils import check_path_exists
@@ -13,7 +13,7 @@ from careamics.utils import check_path_exists
13
13
 
14
14
  def load_pretrained(
15
15
  path: Union[Path, str]
16
- ) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
16
+ ) -> tuple[Union[FCNModule, VAEModule], Configuration]:
17
17
  """
18
18
  Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
19
19
 
@@ -26,8 +26,8 @@ def load_pretrained(
26
26
 
27
27
  Returns
28
28
  -------
29
- Tuple[CAREamicsKiln, Configuration]
30
- Tuple of CAREamics model and its configuration.
29
+ tuple[CAREamicsKiln, Configuration]
30
+ tuple of CAREamics model and its configuration.
31
31
 
32
32
  Raises
33
33
  ------
@@ -48,7 +48,7 @@ def load_pretrained(
48
48
 
49
49
  def _load_checkpoint(
50
50
  path: Union[Path, str]
51
- ) -> Tuple[Union[FCNModule, VAEModule], Configuration]:
51
+ ) -> tuple[Union[FCNModule, VAEModule], Configuration]:
52
52
  """
53
53
  Load a model from a checkpoint and return both model and configuration.
54
54
 
@@ -59,8 +59,8 @@ def _load_checkpoint(
59
59
 
60
60
  Returns
61
61
  -------
62
- Tuple[CAREamicsKiln, Configuration]
63
- Tuple of CAREamics model and its configuration.
62
+ tuple[CAREamicsKiln, Configuration]
63
+ tuple of CAREamics model and its configuration.
64
64
 
65
65
  Raises
66
66
  ------
@@ -92,4 +92,4 @@ def _load_checkpoint(
92
92
  f"{cfg_dict['algorithm_config']['model']['architecture']}"
93
93
  )
94
94
 
95
- return model, Configuration(**cfg_dict)
95
+ return model, configuration_factory(cfg_dict)