careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155) hide show
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,121 @@
1
+ """Functions used to create a README.md file for BMZ export."""
2
+
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import yaml
7
+
8
+ from careamics.config import Configuration
9
+ from careamics.utils import cwd, get_careamics_home
10
+
11
+
12
+ def _yaml_block(yaml_str: str) -> str:
13
+ """Return a markdown code block with a yaml string.
14
+
15
+ Parameters
16
+ ----------
17
+ yaml_str : str
18
+ YAML string.
19
+
20
+ Returns
21
+ -------
22
+ str
23
+ Markdown code block with the YAML string.
24
+ """
25
+ return f"```yaml\n{yaml_str}\n```"
26
+
27
+
28
def readme_factory(
    config: Configuration,
    careamics_version: str,
    data_description: Optional[str] = None,
) -> Path:
    """Create a README file for the model.

    The README is assembled from the algorithm, data and training sections of the
    configuration (each dumped as YAML with ``None`` entries excluded), followed by
    the algorithm references (when non-empty) and a fixed list of links.

    `data_description` can be used to add more information about the content of the
    data the model was trained on.

    The file is written as ``README.md`` from within the
    ``cwd(get_careamics_home())`` context, always UTF-8 encoded, and any existing
    file with that name is overwritten.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    careamics_version : str
        CAREamics version.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Returns
    -------
    Path
        Absolute path to the README file.
    """
    algorithm = config.algorithm_config
    training = config.training_config
    data = config.data_config

    # TODO use tempfile as in the bmz_io module
    with cwd(get_careamics_home()):
        readme = Path("README.md")

        # title built from the algorithm pretty name
        algorithm_flavour = config.get_algorithm_flavour()
        algorithm_pretty_name = algorithm_flavour + " - CAREamics"
        description = [f"# {algorithm_pretty_name}\n\n"]

        # algorithm description
        description.append("Algorithm description:\n\n")
        description.append(config.get_algorithm_description())
        description.append("\n\n")

        # algorithm details
        description.append(
            f"{algorithm_flavour} was trained using CAREamics (version "
            f"{careamics_version}) with the following algorithm "
            f"parameters:\n\n"
        )
        description.append(
            _yaml_block(yaml.dump(algorithm.model_dump(exclude_none=True)))
        )
        description.append("\n\n")

        # data description
        description.append("## Data description\n\n")
        if data_description is not None:
            description.append(data_description)
            description.append("\n\n")

        description.append("The data was processed using the following parameters:\n\n")
        description.append(_yaml_block(yaml.dump(data.model_dump(exclude_none=True))))
        description.append("\n\n")

        # training description
        description.append("## Training description\n\n")
        description.append("The model was trained using the following parameters:\n\n")
        description.append(
            _yaml_block(yaml.dump(training.model_dump(exclude_none=True)))
        )
        description.append("\n\n")

        # references (section skipped when the configuration provides none)
        reference = config.get_algorithm_references()
        if reference != "":
            description.append("## References\n\n")
            description.append(reference)
            description.append("\n\n")

        # links
        description.append(
            "## Links\n\n"
            "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
            "- [CAREamics documentation](https://careamics.github.io/latest/)\n"
        )

        # `write_text` creates the file, so no prior `touch()` is needed; the
        # explicit encoding keeps the output independent of the platform default
        # (descriptions and references may contain non-ASCII characters).
        readme.write_text("".join(description), encoding="utf-8")

        # resolve the relative path while still inside the `cwd` context
        return readme.absolute()
@@ -0,0 +1,52 @@
1
+ """Bioimage.io utils."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+
7
def get_unzip_path(zip_path: Union[Path, str]) -> Path:
    """Derive the unzipped folder path from a bioimage.io model path.

    The folder is located next to the model file and carries the model file name
    with ``.unzip`` appended.

    Parameters
    ----------
    zip_path : Union[Path, str]
        Path to the bioimage.io model.

    Returns
    -------
    Path
        Path to the unzipped folder.
    """
    archive = Path(zip_path)
    folder_name = f"{archive.name}.unzip"

    return archive.parent / folder_name
23
+
24
+
25
def create_env_text(
    pytorch_version: str, torchvision_version: Union[str, None] = None
) -> str:
    """Create environment yaml content for the bioimage model.

    This installs an environment with the specified pytorch version and the latest
    changes to careamics (installed with pip from the GitHub repository).

    torchvision follows its own version numbering, so it is never pinned to the
    PyTorch version: it is pinned to `torchvision_version` when provided, and left
    unpinned otherwise so that the package solver can select a compatible release.

    Parameters
    ----------
    pytorch_version : str
        Pytorch version.
    torchvision_version : Union[str, None], optional
        Torchvision version, by default None (unpinned).

    Returns
    -------
    str
        Environment text.
    """
    if torchvision_version is None:
        torchvision_spec = "torchvision"
    else:
        torchvision_spec = f"torchvision={torchvision_version}"

    # the `pip:` sub-list must be indented deeper than its parent list item
    env = (
        "name: careamics\n"
        "dependencies:\n"
        "  - python=3.10\n"
        f"  - pytorch={pytorch_version}\n"
        f"  - {torchvision_spec}\n"
        "  - pip\n"
        "  - pip:\n"
        "    - git+https://github.com/CAREamics/careamics.git\n"
    )

    return env
@@ -0,0 +1,327 @@
1
+ """Module used to build the BMZ model description."""
2
+
3
+ from pathlib import Path
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ from bioimageio.spec.model.v0_5 import (
8
+ ArchitectureFromLibraryDescr,
9
+ Author,
10
+ AxisBase,
11
+ AxisId,
12
+ BatchAxis,
13
+ ChannelAxis,
14
+ EnvironmentFileDescr,
15
+ FileDescr,
16
+ FixedZeroMeanUnitVarianceAlongAxisKwargs,
17
+ FixedZeroMeanUnitVarianceDescr,
18
+ Identifier,
19
+ InputTensorDescr,
20
+ ModelDescr,
21
+ OutputTensorDescr,
22
+ PytorchStateDictWeightsDescr,
23
+ SpaceInputAxis,
24
+ SpaceOutputAxis,
25
+ TensorId,
26
+ Version,
27
+ WeightsDescr,
28
+ )
29
+
30
+ from careamics.config import Configuration, DataConfig
31
+
32
+ from ._readme_factory import readme_factory
33
+
34
+
35
def _create_axes(
    array: np.ndarray,
    data_config: DataConfig,
    channel_names: Optional[List[str]] = None,
    is_input: bool = True,
) -> List[AxisBase]:
    """Create axes description.

    Array shape is expected to be SC(Z)YX: the first two dimensions are described
    by a batch axis and a channel axis, and the size of each spatial axis is read
    from the following dimensions of `array`.

    Parameters
    ----------
    array : np.ndarray
        Array.
    data_config : DataConfig
        CAREamics data configuration.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    is_input : bool, optional
        Whether the axes are input axes, by default True.

    Returns
    -------
    List[AxisBase]
        List of axes description.

    Raises
    ------
    ValueError
        If channel names are not provided when channel axis is present.
    """
    # batch is always present
    axes_model: List[AxisBase] = [BatchAxis()]

    if "C" not in data_config.axes:
        # singleton channel
        axes_model.append(ChannelAxis(channel_names=[Identifier("channel")]))
    elif channel_names is not None:
        axes_model.append(
            ChannelAxis(channel_names=[Identifier(name) for name in channel_names])
        )
    else:
        raise ValueError(
            f"Channel names must be provided if channel axis is present, axes: "
            f"{data_config.axes}."
        )

    # Only X, Y and Z have a dimension of their own after the batch and channel
    # dimensions. Any other axis in the configuration (e.g. "T") is dropped before
    # enumerating, otherwise it would shift the index used to read `array.shape`.
    spatial_axes = [axis for axis in data_config.axes if axis in ("X", "Y", "Z")]

    space_axis = SpaceInputAxis if is_input else SpaceOutputAxis
    for ind, axis in enumerate(spatial_axes):
        axes_model.append(
            space_axis(id=AxisId(axis.lower()), size=array.shape[2 + ind])
        )

    return axes_model
99
+
100
+
101
def _create_inputs_ouputs(
    input_array: np.ndarray,
    output_array: np.ndarray,
    data_config: DataConfig,
    input_path: Union[Path, str],
    output_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
) -> Tuple[InputTensorDescr, OutputTensorDescr]:
    """Create input and output tensor description.

    Input and output paths must point to a `.npy` file.

    The input description normalizes with the image means and standard deviations
    of the data configuration; the output description applies the same kind of
    operation with statistics chosen so that it inverts that normalization.

    Parameters
    ----------
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    data_config : DataConfig
        CAREamics data configuration.
    input_path : Union[Path, str]
        Path to input .npy file.
    output_path : Union[Path, str]
        Path to output .npy file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.

    Returns
    -------
    Tuple[InputTensorDescr, OutputTensorDescr]
        Input and output tensor descriptions.

    Raises
    ------
    ValueError
        If the image means or standard deviations are missing or empty.
    """
    input_axes = _create_axes(input_array, data_config, channel_names)
    output_axes = _create_axes(output_array, data_config, channel_names, False)

    # mean and std: both are required (explicit check rather than `assert`, which
    # is stripped when running with -O)
    means = data_config.image_means
    stds = data_config.image_stds
    if not means or not stds:
        raise ValueError("Mean and std cannot be None.")

    # and the mean and std required to invert the normalization
    # CAREamics denormalization: x = y * (std + eps) + mean
    # BMZ normalization        : x = (y - mean') / (std' + eps)
    # to apply the BMZ normalization as a denormalization step, we need:
    #   mean' = -mean / (std + eps)   and   std' = 1 / (std + eps) - eps
    eps = 1e-6
    inv_means = [-mean / (std + eps) for mean, std in zip(means, stds)]
    inv_stds = [1 / (std + eps) - eps for _, std in zip(means, stds)]

    # create input/output descriptions
    input_descr = InputTensorDescr(
        id=TensorId("input"),
        axes=input_axes,
        test_tensor=FileDescr(source=input_path),
        preprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                    mean=means, std=stds, axis="channel"
                )
            )
        ],
    )
    output_descr = OutputTensorDescr(
        id=TensorId("prediction"),
        axes=output_axes,
        test_tensor=FileDescr(source=output_path),
        postprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(  # invert norm
                    mean=inv_means, std=inv_stds, axis="channel"
                )
            )
        ],
    )

    return input_descr, output_descr
183
+
184
+
185
def create_model_description(
    config: Configuration,
    name: str,
    general_description: str,
    authors: List[Author],
    inputs: Union[Path, str],
    outputs: Union[Path, str],
    weights_path: Union[Path, str],
    torch_version: str,
    careamics_version: str,
    config_path: Union[Path, str],
    env_path: Union[Path, str],
    channel_names: Optional[List[str]] = None,
    data_description: Optional[str] = None,
) -> ModelDescr:
    """Assemble the BMZ model description.

    A README is generated with `readme_factory`, the input and output arrays are
    loaded from their `.npy` files to describe the tensors, and the weights are
    described as a PyTorch state dict together with the architecture and the
    environment file. The configuration file is added as the only attachment.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    name : str
        Name of the model.
    general_description : str
        General description of the model.
    authors : List[Author]
        Authors of the model.
    inputs : Union[Path, str]
        Path to input .npy file.
    outputs : Union[Path, str]
        Path to output .npy file.
    weights_path : Union[Path, str]
        Path to model weights.
    torch_version : str
        Pytorch version.
    careamics_version : str
        CAREamics version.
    config_path : Union[Path, str]
        Path to model configuration.
    env_path : Union[Path, str]
        Path to environment file.
    channel_names : Optional[List[str]], optional
        Channel names, by default None.
    data_description : Optional[str], optional
        Description of the data, by default None.

    Returns
    -------
    ModelDescr
        Model description.
    """
    # documentation (README file)
    readme_path = readme_factory(
        config,
        careamics_version=careamics_version,
        data_description=data_description,
    )

    # input and output tensors, described from the arrays stored on disk
    input_array = np.load(inputs)
    output_array = np.load(outputs)
    input_descr, output_descr = _create_inputs_ouputs(
        input_array=input_array,
        output_array=output_array,
        data_config=config.data_config,
        input_path=inputs,
        output_path=outputs,
        channel_names=channel_names,
    )

    # architecture and weights
    model_config = config.algorithm_config.model
    architecture_descr = ArchitectureFromLibraryDescr(
        import_from="careamics.models.unet",
        callable=f"{model_config.architecture}",
        kwargs=model_config.model_dump(),
    )
    state_dict_descr = PytorchStateDictWeightsDescr(
        source=weights_path,
        architecture=architecture_descr,
        pytorch_version=Version(torch_version),
        dependencies=EnvironmentFileDescr(source=env_path),
    )
    weights_descr = WeightsDescr(pytorch_state_dict=state_dict_descr)

    # conversion from float32 to float64 creates small differences, so we relax
    # the constraints on the decimals
    bioimageio_config = {
        "bioimageio": {
            "test_kwargs": {
                "pytorch_state_dict": {
                    "decimals": 0,
                }
            }
        }
    }

    # overall model description
    tags = config.get_algorithm_keywords()
    attachments = [FileDescr(source=config_path)]
    citations = config.get_algorithm_citations()

    return ModelDescr(
        name=name,
        authors=authors,
        description=general_description,
        documentation=readme_path,
        inputs=[input_descr],
        outputs=[output_descr],
        tags=tags,
        links=[
            "https://github.com/CAREamics/careamics",
            "https://careamics.github.io/latest/",
        ],
        license="BSD-3-Clause",
        version="0.1.0",
        weights=weights_descr,
        attachments=attachments,
        cite=citations,
        config=bioimageio_config,
    )
299
+
300
+
301
def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]:
    """Return the relative path to the weights and configuration files.

    When the model has exactly one attachment, it is taken as the configuration
    file whatever its suffix. Otherwise the first attachment with a ``.yml``
    suffix is used.

    Parameters
    ----------
    model_desc : ModelDescr
        Model description.

    Returns
    -------
    Tuple[Path, Path]
        Weights and configuration paths.

    Raises
    ------
    ValueError
        If no configuration file can be found among the attachments.
    """
    weights_path = model_desc.weights.pytorch_state_dict.source.path

    attachments = model_desc.attachments
    if len(attachments) == 1:
        config_path = attachments[0].source.path
    else:
        # defaults to None so that a missing configuration (no attachment, or no
        # `.yml` among several) reaches the explicit error below
        config_path = next(
            (
                file.source.path
                for file in attachments
                if file.source.path.suffix == ".yml"
            ),
            None,
        )

    if config_path is None:
        raise ValueError("Configuration file not found.")

    return weights_path, config_path