careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,251 @@
1
+ """Function to export to the BioImage Model Zoo format."""
2
+
3
+ import tempfile
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+ from bioimageio.core import load_model_description, test_model
9
+ from bioimageio.spec import ValidationSummary, save_bioimageio_package
10
+ from pydantic import HttpUrl
11
+ from torch import __version__ as PYTORCH_VERSION
12
+ from torch import load, save
13
+ from torchvision import __version__ as TORCHVISION_VERSION
14
+
15
+ from careamics.config import Configuration, load_configuration, save_configuration
16
+ from careamics.config.support import SupportedArchitecture
17
+ from careamics.lightning.lightning_module import FCNModule, VAEModule
18
+ from careamics.utils.version import get_careamics_version
19
+
20
+ from .bioimage import (
21
+ create_env_text,
22
+ create_model_description,
23
+ extract_model_path,
24
+ )
25
+ from .bioimage.cover_factory import create_cover
26
+
27
+
28
def _export_state_dict(
    model: Union[FCNModule, VAEModule], path: Union[Path, str]
) -> Path:
    """
    Export the model state dictionary to a file.

    The file suffix is forced to `.pth`: any other suffix, including a missing
    one, is replaced.

    Parameters
    ----------
    model : FCNModule or VAEModule
        CAREamics Lightning module whose wrapped torch model is exported.
    path : Union[Path, str]
        Path to the file where to save the model state dictionary.

    Returns
    -------
    Path
        Path to the saved model state dictionary (always ending in `.pth`).
    """
    path = Path(path)

    # make sure it has the correct suffix; this must be an equality test, a
    # substring test (`not in ".pth"`) would accept "", ".p", ".pt", ".th", ...
    if path.suffix != ".pth":
        path = path.with_suffix(".pth")

    # save model state dictionary
    # we save through the torch model itself to avoid the initial "model." in the
    # layers naming, which is incompatible with the way the BMZ load torch state dicts
    save(model.model.state_dict(), path)

    return path
58
+
59
+
60
def _load_state_dict(
    model: Union[FCNModule, VAEModule], path: Union[Path, str]
) -> None:
    """
    Load weights from a state dictionary into a CAREamics model, in place.

    Parameters
    ----------
    model : FCNModule or VAEModule
        CAREamics model to be updated with the weights; the state dictionary is
        loaded into its wrapped torch model (`model.model`).
    path : Union[Path, str]
        Path to the model state dictionary.
    """
    path = Path(path)

    # load model state dictionary
    # same as in _export_state_dict, we load through the torch model to be compatible
    # with bioimageio.core expectations for a torch state dict
    # - map_location="cpu": weights exported from a GPU can be read on a CPU-only
    #   machine; `load_state_dict` then copies them into the model's own parameters
    # - weights_only=True: the file may come from a downloaded archive, so avoid
    #   unpickling arbitrary objects (a state dict only holds tensors)
    state_dict = load(path, map_location="cpu", weights_only=True)
    model.model.load_state_dict(state_dict)
80
+
81
+
82
+ # TODO break down in subfunctions
83
def export_to_bmz(
    model: Union[FCNModule, VAEModule],
    config: Configuration,
    path_to_archive: Union[Path, str],
    model_name: str,
    general_description: str,
    data_description: str,
    authors: list[dict],
    input_array: np.ndarray,
    output_array: np.ndarray,
    covers: list[Union[Path, str]] | None = None,
    channel_names: list[str] | None = None,
    model_version: str = "0.1.0",
) -> None:
    """Export the model to BioImage Model Zoo format.

    Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C.

    `model_name` should consist of letters, numbers, dashes, underscores and parentheses
    only.

    Parameters
    ----------
    model : FCNModule or VAEModule
        CAREamics model to export.
    config : Configuration
        Model configuration.
    path_to_archive : Union[Path, str]
        Path to the output file, must have a `.zip` suffix. Missing parent
        folders are created.
    model_name : str
        Model name.
    general_description : str
        General description of the model.
    data_description : str
        Description of the data the model was trained on.
    authors : list[dict]
        Authors of the model.
    input_array : np.ndarray
        Input array, should not have been normalized.
    output_array : np.ndarray
        Output array, should have been denormalized.
    covers : list of pathlib.Path or str, default=None
        Paths to the cover images. If None, a cover is generated from the input
        and output arrays.
    channel_names : list of str, optional
        Channel names, by default None.
    model_version : str, default="0.1.0"
        Model version.

    Raises
    ------
    ValueError
        If `path_to_archive` does not have a `.zip` suffix.
    ValueError
        If the bioimageio test of the model description reports a failure.
    """
    path_to_archive = Path(path_to_archive)

    if path_to_archive.suffix != ".zip":
        raise ValueError(
            f"Path to archive must point to a zip file, got {path_to_archive}."
        )

    # `exist_ok=True` already makes this a no-op when the folder exists
    path_to_archive.parent.mkdir(parents=True, exist_ok=True)

    # versions
    careamics_version = get_careamics_version()

    # all intermediate files live in a temporary folder, which is deleted on exit,
    # so the package must be written before leaving the `with` block
    with tempfile.TemporaryDirectory() as tmpdirname:
        temp_path = Path(tmpdirname)

        # create environment file
        # TODO move in bioimage module
        env_path = temp_path / "environment.yml"
        env_path.write_text(create_env_text(PYTORCH_VERSION, TORCHVISION_VERSION))

        # export inputs and outputs
        inputs = temp_path / "inputs.npy"
        np.save(inputs, input_array)
        outputs = temp_path / "outputs.npy"
        np.save(outputs, output_array)

        # export configuration
        config_path = save_configuration(config, temp_path / "careamics.yaml")

        # export model state dictionary
        weight_path = _export_state_dict(model, temp_path / "weights.pth")

        # export cover if necessary
        if covers is None:
            covers = [create_cover(temp_path, input_array, output_array)]

        # create model description
        model_description = create_model_description(
            config=config,
            name=model_name,
            general_description=general_description,
            data_description=data_description,
            authors=authors,
            inputs=inputs,
            outputs=outputs,
            weights_path=weight_path,
            torch_version=PYTORCH_VERSION,
            careamics_version=careamics_version,
            config_path=config_path,
            env_path=env_path,
            covers=covers,
            channel_names=channel_names,
            model_version=model_version,
        )

        # test model description, forwarding any test kwargs it declares
        test_kwargs = _get_pytorch_state_dict_test_kwargs(model_description)
        summary: ValidationSummary = test_model(model_description, **test_kwargs)
        if summary.status == "failed":
            raise ValueError(f"Model description test failed: {summary}")

        # save bmz model
        save_bioimageio_package(model_description, output_path=path_to_archive)


def _get_pytorch_state_dict_test_kwargs(model_description) -> dict:
    """
    Extract the bioimageio test kwargs for the `pytorch_state_dict` weights.

    The kwargs are read from `model_description.config["bioimageio"]["test_kwargs"]
    ["pytorch_state_dict"]`; any missing level yields an empty dictionary.

    Parameters
    ----------
    model_description
        bioimageio model description, as returned by `create_model_description`.

    Returns
    -------
    dict
        Keyword arguments to forward to `bioimageio.core.test_model`.
    """
    config = getattr(model_description, "config", None)
    if not isinstance(config, dict):
        return {}

    bioimageio_config = config.get("bioimageio", {})
    return bioimageio_config.get("test_kwargs", {}).get("pytorch_state_dict", {})
204
+
205
+
206
def load_from_bmz(
    path: Union[Path, str, HttpUrl],
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
    """Load a model from a BioImage Model Zoo archive.

    Parameters
    ----------
    path : Path, str or HttpUrl
        Path to the BioImage Model Zoo archive. A Http URL must point to a downloadable
        location.

    Returns
    -------
    FCNModule or VAEModule
        The loaded CAREamics model (FCNModule for UNet, VAEModule for LVAE).
    Configuration
        The loaded CAREamics configuration.

    Raises
    ------
    ValueError
        If the architecture in the archived configuration is neither UNet nor LVAE.
    """
    # load description
    # NOTE(review): previously documented as creating an unzipped folder next to
    # the archive; this depends on bioimageio.core, confirm for the pinned version
    model_desc = load_model_description(path)

    # extract paths
    weights_path, config_path = extract_model_path(model_desc)

    # load configuration
    config = load_configuration(config_path)

    # create careamics lightning module matching the archived architecture
    architecture = config.algorithm_config.model.architecture
    model: Union[FCNModule, VAEModule]
    if architecture == SupportedArchitecture.UNET:
        model = FCNModule(algorithm_config=config.algorithm_config)
    elif architecture == SupportedArchitecture.LVAE:
        model = VAEModule(algorithm_config=config.algorithm_config)
    else:
        raise ValueError(f"Unsupported architecture {architecture}")

    # load model state dictionary
    _load_state_dict(model, weights_path)

    return model, config
@@ -0,0 +1,95 @@
1
+ """Utility functions to load pretrained models."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from careamics.config import Configuration
9
+ from careamics.lightning.lightning_module import FCNModule, VAEModule
10
+ from careamics.model_io.bmz_io import load_from_bmz
11
+ from careamics.utils import check_path_exists
12
+
13
+
14
def load_pretrained(
    path: Union[Path, str],
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
    """
    Load a pretrained CAREamics model together with its configuration.

    The loader is chosen from the file suffix: `.ckpt` files are read as
    checkpoints, `.zip` files as BioImage Model Zoo archives.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the pretrained model; it is validated with `check_path_exists`.

    Returns
    -------
    tuple[FCNModule or VAEModule, Configuration]
        The CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the suffix is neither `.ckpt` nor `.zip`.
    """
    checked_path = check_path_exists(path)
    suffix = checked_path.suffix

    if suffix == ".ckpt":
        return _load_checkpoint(checked_path)
    if suffix == ".zip":
        return load_from_bmz(checked_path)

    raise ValueError(
        f"Invalid model format. Expected .ckpt or .zip, got {suffix}."
    )
47
+
48
+
49
def _load_checkpoint(
    path: Union[Path, str],
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
    """
    Load a model from a checkpoint and return both model and configuration.

    Parameters
    ----------
    path : Union[Path, str]
        Path to the checkpoint.

    Returns
    -------
    tuple[FCNModule or VAEModule, Configuration]
        tuple of CAREamics model and its configuration.

    Raises
    ------
    ValueError
        If the checkpoint file does not contain hyper parameters (configuration).
    ValueError
        If the architecture stored in the hyper parameters is neither "UNet" nor
        "LVAE".
    """
    # Only the hyper parameters are read from this first load, so map all tensors
    # to the CPU: mapping to a GPU would allocate device memory for weights (and
    # any optimizer state) that are discarded right away.
    # see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
    checkpoint: dict = torch.load(path, map_location="cpu")

    # attempt to load configuration
    try:
        cfg_dict = checkpoint["hyper_parameters"]
    except KeyError as e:
        raise ValueError(
            f"Invalid checkpoint file. No `hyper_parameters` found in the "
            f"checkpoint: {checkpoint.keys()}"
        ) from e

    # release the full checkpoint before the Lightning module reloads the file,
    # to keep peak memory down
    del checkpoint

    architecture = cfg_dict["algorithm_config"]["model"]["architecture"]
    model: Union[FCNModule, VAEModule]
    if architecture == "UNet":
        model = FCNModule.load_from_checkpoint(path)
    elif architecture == "LVAE":
        model = VAEModule.load_from_checkpoint(path)
    else:
        raise ValueError(f"Invalid model architecture: {architecture}")

    return model, Configuration(**cfg_dict)
@@ -0,0 +1,5 @@
1
+ """Models package."""
2
+
3
from .model_factory import model_factory

# Names re-exported as the public API of this package.
__all__ = ["model_factory"]
@@ -0,0 +1,40 @@
1
+ """Activations for CAREamics models."""
2
+
3
+ from collections.abc import Callable
4
+ from typing import Union
5
+
6
+ import torch.nn as nn
7
+
8
+ from ..config.support import SupportedActivation
9
+
10
+
11
def get_activation(activation: Union[SupportedActivation, str]) -> Callable:
    """
    Build a fresh torch activation layer for the requested activation.

    Parameters
    ----------
    activation : SupportedActivation or str
        Activation to build, compared with `==` against the `SupportedActivation`
        members.

    Returns
    -------
    Callable
        A new `torch.nn` module instance. `SupportedActivation.SOFTMAX` maps to
        `nn.Softmax(dim=1)` and `SupportedActivation.NONE` to `nn.Identity()`.

    Raises
    ------
    ValueError
        If `activation` matches none of the supported activations.
    """
    # Ordered (member, constructor) pairs scanned with `==` rather than looked up
    # in a dict, so that matching keeps plain equality semantics between strings
    # and enum members.
    constructors: tuple[tuple[SupportedActivation, Callable[[], nn.Module]], ...] = (
        (SupportedActivation.RELU, nn.ReLU),
        (SupportedActivation.ELU, nn.ELU),
        (SupportedActivation.LEAKYRELU, nn.LeakyReLU),
        (SupportedActivation.TANH, nn.Tanh),
        (SupportedActivation.SIGMOID, nn.Sigmoid),
        (SupportedActivation.SOFTMAX, lambda: nn.Softmax(dim=1)),
        (SupportedActivation.NONE, nn.Identity),
    )

    for member, build in constructors:
        if activation == member:
            return build()

    raise ValueError(f"Activation {activation} not supported.")