careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,113 @@
1
+ """Functions used to create a README.md file for BMZ export."""
2
+
3
+ from pathlib import Path
4
+
5
+ import yaml
6
+
7
+ from careamics.config import Configuration
8
+ from careamics.utils import cwd, get_careamics_home
9
+
10
+
11
def _yaml_block(yaml_str: str) -> str:
    """Wrap a YAML string in a fenced markdown code block tagged as `yaml`.

    Parameters
    ----------
    yaml_str : str
        YAML content to embed.

    Returns
    -------
    str
        The YAML content surrounded by a "```yaml" opening fence and a "```"
        closing fence, each on its own line.
    """
    fence = "```"
    return f"{fence}yaml\n{yaml_str}\n{fence}"
25
+
26
+
27
def readme_factory(
    config: Configuration,
    careamics_version: str,
    data_description: str,
) -> Path:
    """Create a README file for the model.

    `data_description` can be used to add more information about the content of the
    data the model was trained on.

    The README is written as `README.md` inside the CAREamics home directory
    (the code enters `cwd(get_careamics_home())` before creating it). It
    contains the following sections:
    - a title;
    - the data and algorithm descriptions;
    - the YAML-dumped configuration;
    - validation advice;
    - references, only if the configuration provides any;
    - links.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    careamics_version : str
        CAREamics version.
    data_description : str
        Description of the data.

    Returns
    -------
    Path
        Absolute path to the README file.
    """
    # TODO use tempfile as in the bmz_io module
    with cwd(get_careamics_home()):
        readme = Path("README.md")

        # algorithm pretty name
        algorithm_flavour = config.get_algorithm_friendly_name()
        algorithm_pretty_name = algorithm_flavour + " - CAREamics"

        description = [f"# {algorithm_pretty_name}\n\n"]

        # data description
        description.append("## Data description\n\n")
        description.append(data_description)
        description.append("\n\n")

        # algorithm description
        description.append("## Algorithm description:\n\n")
        description.append(config.get_algorithm_description())
        description.append("\n\n")

        # configuration description
        description.append("## Configuration\n\n")
        description.append(
            f"{algorithm_flavour} was trained using CAREamics (version "
            f"{careamics_version}) using the following configuration:\n\n"
        )
        description.append(_yaml_block(yaml.dump(config.model_dump(exclude_none=True))))
        description.append("\n\n")

        # validation (H2, like every other section below the title)
        description.append("## Validation\n\n")
        description.append(
            "In order to validate the model, we encourage users to acquire a "
            "test dataset with ground-truth data. Comparing the ground-truth data "
            "with the prediction allows unbiased evaluation of the model performance. "
            "This can be done for instance by using metrics such as PSNR, SSIM, or "
            "MicroSSIM. In the absence of ground-truth, inspecting the residual image "
            "(difference between input and predicted image) can be helpful to identify "
            "whether real signal is removed from the input image.\n\n"
        )

        # references, only when the configuration provides some
        reference = config.get_algorithm_references()
        if reference != "":
            description.append("## References\n\n")
            description.append(reference)
            description.append("\n\n")

        # links
        description.append(
            "## Links\n\n"
            "- [CAREamics repository](https://github.com/CAREamics/careamics)\n"
            "- [CAREamics documentation](https://careamics.github.io/)\n"
        )

        # explicit encoding so the output does not depend on the platform locale
        readme.write_text("".join(description), encoding="utf-8")

        # `readme` is a relative path: resolve it while still inside the
        # `cwd(...)` context so it points into the CAREamics home directory
        return readme.absolute()
@@ -0,0 +1,56 @@
1
+ """Bioimage.io utils."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ from careamics.utils.version import get_careamics_version
7
+
8
+
9
def get_unzip_path(zip_path: Union[Path, str]) -> Path:
    """Return the folder path used to unzip a bioimage.io model archive.

    The folder sits next to the archive. Its name is the archive's file name
    with ".unzip" appended, e.g. `model.zip` -> `model.zip.unzip`.

    Parameters
    ----------
    zip_path : Path or str
        Path to the bioimage.io model archive.

    Returns
    -------
    Path
        Path to the unzipped folder.
    """
    archive = Path(zip_path)
    unzipped_name = f"{archive.name}.unzip"
    return archive.parent / unzipped_name
25
+
26
+
27
def create_env_text(pytorch_version: str, torchvision_version: str) -> str:
    """Create environment yaml content for the bioimage model.

    The environment uses Python 3.12 and pip-installs:
    - the requested torch and torchvision versions;
    - careamics, pinned to the version returned by `get_careamics_version()`.

    Parameters
    ----------
    pytorch_version : str
        Pytorch version.
    torchvision_version : str
        Torchvision version.

    Returns
    -------
    str
        Environment text.
    """
    # The pip packages must be indented one level deeper than `- pip:` so
    # that they form the list under the `pip` key. At the same level they
    # would be parsed as conda dependencies, and `pip:` would be null.
    env = (
        "name: careamics\n"
        "dependencies:\n"
        "  - python=3.12\n"
        "  - pip\n"
        "  - pip:\n"
        f"    - torch=={pytorch_version}\n"
        f"    - torchvision=={torchvision_version}\n"
        f"    - careamics=={get_careamics_version()}"
    )

    return env
@@ -0,0 +1,171 @@
1
+ """Convenience function to create covers for the BMZ."""
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+ from PIL import Image
8
+
9
# RGB colours used to blend the first four channels of multi-channel images
# into a single RGB cover image (one row per channel).
color_palette = np.array(
    [
        [255, 195, 0],  # amber / yellow
        [189, 226, 240],  # light blue
        [96, 60, 76],  # dark mauve
        [193, 225, 193],  # pale green
    ]
)
17
+
18
+
19
def _get_norm_slice(array: NDArray) -> NDArray:
    """Get the normalized middle slice of a 4D or 5D array (SC(Z)YX).

    Slice selection:
    - only the first sample is used;
    - for 5D arrays, only the middle Z plane is used.

    Channel handling:
    - multi-channel data is returned channel-last (YXC);
    - single-channel data is returned as YX.

    Values are min-max scaled to [0, 255] over the whole slice and cast to
    uint8. A constant slice is mapped to zeros.

    Parameters
    ----------
    array : NDArray
        Array from which to get the middle slice.

    Returns
    -------
    NDArray
        Normalized middle slice of the input array, as uint8.

    Raises
    ------
    ValueError
        If the array is not 4D or 5D.
    """
    if array.ndim not in (4, 5):
        raise ValueError("Array must be 4D or 5D.")

    has_channels = array.shape[1] > 1

    # first sample; for 5D arrays, also the middle Z plane
    if array.ndim == 5:
        array_slice = array[0, :, array.shape[2] // 2, ...]
    else:
        array_slice = array[0, ...]

    # channel-last for multi-channel data, drop the singleton channel otherwise
    if has_channels:
        array_slice = np.moveaxis(array_slice, 0, -1)
    else:
        array_slice = array_slice[0, ...]

    # Compute in float: with integer inputs, `255 * (x - min)` would overflow
    # in the input dtype (e.g. uint8).
    array_slice = array_slice.astype(np.float64)
    min_val = array_slice.min()
    value_range = array_slice.max() - min_val

    # a constant slice would otherwise divide by zero and produce NaNs
    if value_range == 0:
        return np.zeros(array_slice.shape, dtype=np.uint8)

    array_slice = 255 * (array_slice - min_val) / value_range

    return array_slice.astype(np.uint8)
58
+
59
+
60
def _four_channel_image(array: NDArray) -> Image.Image:
    """Convert a normalized multi-channel array to an RGB Image.

    Each of the first four channels is paired with the matching row of
    `color_palette`. Channel intensities are rescaled to [0, 1] and used to
    weight their colours, and the weighted colours are summed per pixel. The
    result is clipped to [0, 255] before the uint8 cast.

    Parameters
    ----------
    array : NDArray
        Normalized (uint8-range) channel-last array (YXC) with at least four
        channels; only the first four are used.

    Returns
    -------
    Image.Image
        RGB image.
    """
    colors = color_palette[np.newaxis, np.newaxis, :, :]
    # Weights in [0, 1]. Multiplying raw 0-255 intensities by 0-255 colours
    # would exceed uint8 range and wrap around on the final cast.
    weights = array[..., :4, np.newaxis].astype(np.float64) / 255.0
    rgb = np.sum(weights * colors, axis=-2)
    # clip because overlapping bright channels can sum above 255
    rgb = np.clip(rgb, 0, 255).astype(np.uint8)

    return Image.fromarray(rgb).convert("RGB")
77
+
78
+
79
def _convert_to_image(original_shape: tuple[int, ...], array: NDArray) -> Image.Image:
    """Convert a normalized slice to an RGB Image.

    The channel count is read from `original_shape[1]` (SC(Z)YX layout).
    Conversion depends on that count:
    - 1 channel: converted to grayscale, then to RGB.
    - 2 channels: an empty channel is prepended to obtain 3 channels.
    - 3 channels: used directly as RGB.
    - 4 or more channels: the first four are colour-blended.

    Parameters
    ----------
    original_shape : tuple
        Original shape of the array.
    array : NDArray
        Normalized array to convert (YX, or YXC when multi-channel).

    Returns
    -------
    Image.Image
        Converted array.
    """
    n_channels = original_shape[1]

    if n_channels <= 1:
        return Image.fromarray(array).convert("L").convert("RGB")

    if n_channels == 2:
        # prepend an empty channel so the array has three (RGB) channels
        array = np.concatenate([np.zeros_like(array[..., 0:1]), array], axis=-1)
        return Image.fromarray(array).convert("RGB")

    if n_channels == 3:
        return Image.fromarray(array).convert("RGB")

    # four or more channels: blend the first four into RGB
    return _four_channel_image(array[..., :4])
108
+
109
+
110
def create_cover(directory: Path, array_in: NDArray, array_out: NDArray) -> Path:
    """Create a cover image from input and output arrays.

    Input and output arrays are expected to be SC(Z)YX. For images with a Z
    dimension, the middle slice is taken.

    How the cover is composed:
    - If input and output images have the same width, the cover shows the
      left half of the input next to the right half of the output.
    - Otherwise, if they have the same height, it shows the top half of the
      input above the bottom half of the output.
    - The cover has the size of the input image.

    Parameters
    ----------
    directory : Path
        Directory in which to save the cover.
    array_in : numpy.ndarray
        Array from which to create the cover image.
    array_out : numpy.ndarray
        Array from which to create the cover image.

    Returns
    -------
    Path
        Path to the saved cover image.

    Raises
    ------
    ValueError
        If the input and output images share neither width nor height.
    """
    # extract slice and normalize arrays
    slice_in = _get_norm_slice(array_in)
    slice_out = _get_norm_slice(array_out)

    # convert to Image
    image_in = _convert_to_image(array_in.shape, slice_in)
    image_out = _convert_to_image(array_out.shape, slice_out)

    # Compare sizes on the images rather than on the slices. Slices may be YX
    # or YXC, so their last dimension is not always the width.
    width, height = image_in.size
    split_left_right = image_out.width == width
    if not split_left_right and image_out.height != height:
        raise ValueError("Input and output arrays have different shapes.")

    cover = Image.new("RGB", (width, height))
    if split_left_right:
        half = width // 2
        cover.paste(image_in.crop((0, 0, half, height)), (0, 0))
        # crop from `half` so the output half stays aligned and fills odd widths
        cover.paste(image_out.crop((half, 0, width, height)), (half, 0))
    else:
        half = height // 2
        cover.paste(image_in.crop((0, 0, width, half)), (0, 0))
        cover.paste(image_out.crop((0, half, width, height)), (0, half))

    # save
    cover_path = Path(directory) / "cover.png"
    cover.save(cover_path)

    return cover_path
@@ -0,0 +1,341 @@
1
+ """Module use to build BMZ model description."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ from bioimageio.spec._internal.io import extract
8
+ from bioimageio.spec.model.v0_5 import (
9
+ ArchitectureFromLibraryDescr,
10
+ Author,
11
+ AxisBase,
12
+ AxisId,
13
+ BatchAxis,
14
+ ChannelAxis,
15
+ FileDescr,
16
+ FixedZeroMeanUnitVarianceAlongAxisKwargs,
17
+ FixedZeroMeanUnitVarianceDescr,
18
+ Identifier,
19
+ InputTensorDescr,
20
+ ModelDescr,
21
+ OutputTensorDescr,
22
+ PytorchStateDictWeightsDescr,
23
+ SpaceInputAxis,
24
+ SpaceOutputAxis,
25
+ TensorId,
26
+ Version,
27
+ WeightsDescr,
28
+ )
29
+
30
+ from careamics.config import Configuration, DataConfig
31
+
32
+ from ._readme_factory import readme_factory
33
+
34
+
35
def _create_axes(
    array: np.ndarray,
    data_config: DataConfig,
    channel_names: list[str] | None = None,
    is_input: bool = True,
) -> list[AxisBase]:
    """Create axes description.

    Array shape is expected to be SC(Z)YX. The description always contains:
    - a batch axis;
    - a channel axis, which is a single "channel" when the configuration has
      no C axis;
    - one space axis per Z/Y/X axis of the configuration. Its size is read
      from the array dimension following the S and C dimensions.

    Parameters
    ----------
    array : np.ndarray
        Array.
    data_config : DataConfig
        CAREamics data configuration.
    channel_names : list[str] or None, optional
        Channel names, by default None.
    is_input : bool, optional
        Whether the axes are input axes, by default True.

    Returns
    -------
    list[AxisBase]
        list of axes description.

    Raises
    ------
    ValueError
        If channel names are not provided when channel axis is present.
    """
    # Keep only the spatial axes. Any other axis (S, C, T, ...) has no
    # dimension of its own after the first two dimensions of the SC(Z)YX
    # array, so it must not shift the spatial indices below.
    spatial_axes = [axis for axis in data_config.axes if axis in ("Z", "Y", "X")]

    # batch is always present
    axes_model: list[AxisBase] = [BatchAxis()]

    if "C" in data_config.axes:
        if channel_names is None:
            raise ValueError(
                f"Channel names must be provided if channel axis is present, axes: "
                f"{data_config.axes}."
            )
        axes_model.append(
            ChannelAxis(channel_names=[Identifier(name) for name in channel_names])
        )
    else:
        # singleton channel
        axes_model.append(ChannelAxis(channel_names=[Identifier("channel")]))

    # spatial axes: sizes start at array dimension 2 (after S and C)
    space_axis_cls = SpaceInputAxis if is_input else SpaceOutputAxis
    for ind, axis in enumerate(spatial_axes):
        axes_model.append(
            space_axis_cls(id=AxisId(axis.lower()), size=array.shape[2 + ind])
        )

    return axes_model
99
+
100
+
101
def _create_inputs_ouputs(
    input_array: np.ndarray,
    output_array: np.ndarray,
    data_config: DataConfig,
    input_path: Union[Path, str],
    output_path: Union[Path, str],
    channel_names: list[str] | None = None,
) -> tuple[InputTensorDescr, OutputTensorDescr]:
    """Create input and output tensor description.

    Input and output paths must point to a `.npy` file. The two descriptions
    are set up as follows:
    - The input is normalized with the configuration's per-channel means and
      standard deviations.
    - The output is post-processed with the inverse transform, so that BMZ
      consumers obtain denormalized predictions.

    Parameters
    ----------
    input_array : np.ndarray
        Input array.
    output_array : np.ndarray
        Output array.
    data_config : DataConfig
        CAREamics data configuration.
    input_path : Union[Path, str]
        Path to input .npy file.
    output_path : Union[Path, str]
        Path to output .npy file.
    channel_names : list[str] or None, optional
        Channel names, by default None.

    Returns
    -------
    tuple[InputTensorDescr, OutputTensorDescr]
        Input and output tensor descriptions.

    Raises
    ------
    ValueError
        In any of these cases:
        - the means or stds are missing or empty;
        - the means and stds have different lengths.
    """
    input_axes = _create_axes(input_array, data_config, channel_names)
    output_axes = _create_axes(output_array, data_config, channel_names, False)

    # mean and std (explicit checks: `assert` is stripped under `python -O`)
    means = data_config.image_means
    stds = data_config.image_stds
    if means is None:
        raise ValueError("Mean cannot be None.")
    if stds is None:
        raise ValueError("Std cannot be None.")
    if not means or not stds:
        raise ValueError("Mean and std cannot be None.")
    if len(means) != len(stds):
        raise ValueError(
            f"Mean and std must have the same length, got {len(means)} and "
            f"{len(stds)}."
        )

    # The mean and std below are those required to invert the normalization.
    #   CAREamics denormalization: x = y * (std + eps) + mean
    #   BMZ normalization:         x = (y - mean') / (std' + eps)
    # Setting mean' = -mean / (std + eps) and std' = 1 / (std + eps) - eps
    # makes the BMZ normalization perform the CAREamics denormalization.
    eps = 1e-6
    inv_means = [-mean / (std + eps) for mean, std in zip(means, stds, strict=True)]
    inv_stds = [1 / (std + eps) - eps for std in stds]

    # create input/output descriptions
    input_descr = InputTensorDescr(
        id=TensorId("input"),
        axes=input_axes,
        test_tensor=FileDescr(source=input_path),
        preprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
                    mean=means, std=stds, axis="channel"
                )
            )
        ],
    )
    output_descr = OutputTensorDescr(
        id=TensorId("prediction"),
        axes=output_axes,
        test_tensor=FileDescr(source=output_path),
        postprocessing=[
            FixedZeroMeanUnitVarianceDescr(
                kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(  # invert norm
                    mean=inv_means, std=inv_stds, axis="channel"
                )
            )
        ],
    )

    return input_descr, output_descr
183
+
184
+
185
def create_model_description(
    config: Configuration,
    name: str,
    general_description: str,
    data_description: str,
    authors: list[Author],
    inputs: Union[Path, str],
    outputs: Union[Path, str],
    weights_path: Union[Path, str],
    torch_version: str,
    careamics_version: str,
    config_path: Union[Path, str],
    env_path: Union[Path, str],
    covers: list[Union[Path, str]],
    channel_names: list[str] | None = None,
    model_version: str = "0.1.0",
) -> ModelDescr:
    """Assemble the bioimage.io model description of a CAREamics model.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration the model was trained with.
    name : str
        Model name.
    general_description : str
        Free-text description of the model.
    data_description : str
        Description of the training data, forwarded to the README factory.
    authors : list[Author]
        Model authors.
    inputs : Union[Path, str]
        Path to the `.npy` file holding the test input tensor.
    outputs : Union[Path, str]
        Path to the `.npy` file holding the test output tensor.
    weights_path : Union[Path, str]
        Path to the PyTorch state dict.
    torch_version : str
        PyTorch version string, wrapped in a `Version`.
    careamics_version : str
        CAREamics version, forwarded to the README factory.
    config_path : Union[Path, str]
        Path to the configuration file, added as an attachment.
    env_path : Union[Path, str]
        Path to the environment file declared as the weights' dependencies.
    covers : list of pathlib.Path or str
        Paths to the cover images.
    channel_names : list of str, optional
        Channel names used in the tensor axes, by default None.
    model_version : str, default "0.1.0"
        Version of the model.

    Returns
    -------
    ModelDescr
        The assembled model description.
    """
    # README generated from the configuration
    readme = readme_factory(
        config,
        careamics_version=careamics_version,
        data_description=data_description,
    )

    # the test tensors are loaded so that their axes can be described
    input_array = np.load(inputs)
    output_array = np.load(outputs)
    input_descr, output_descr = _create_inputs_ouputs(
        input_array=input_array,
        output_array=output_array,
        data_config=config.data_config,
        input_path=inputs,
        output_path=outputs,
        channel_names=channel_names,
    )

    # the network is re-created from `careamics.models.unet` using the
    # serialized model configuration as constructor kwargs
    model_cfg = config.algorithm_config.model
    architecture = ArchitectureFromLibraryDescr(
        import_from="careamics.models.unet",
        callable=f"{model_cfg.architecture}",
        kwargs=model_cfg.model_dump(),
    )
    state_dict_weights = PytorchStateDictWeightsDescr(
        source=weights_path,
        architecture=architecture,
        pytorch_version=Version(torch_version),
        dependencies=FileDescr(source=Path(env_path)),
    )
    weights = WeightsDescr(pytorch_state_dict=state_dict_weights)

    # tolerances passed to the bioimage.io test of the state dict weights
    tolerances = {"absolute_tolerance": 1e-2, "relative_tolerance": 1e-2}
    bioimageio_config = {
        "bioimageio": {"test_kwargs": {"pytorch_state_dict": tolerances}}
    }
    project_links = [
        "https://github.com/CAREamics/careamics",
        "https://careamics.github.io/latest/",
    ]

    return ModelDescr(
        name=name,
        authors=authors,
        description=general_description,
        documentation=readme,
        inputs=[input_descr],
        outputs=[output_descr],
        tags=config.get_algorithm_keywords(),
        links=project_links,
        license="BSD-3-Clause",
        config=bioimageio_config,
        version=model_version,
        weights=weights,
        attachments=[FileDescr(source=config_path)],
        cite=config.get_algorithm_citations(),
        covers=covers,
    )
307
+
308
+
309
def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
    """Return the paths to the weights and configuration files.

    The model package referenced by `model_desc.root` is extracted with
    `extract`. The weights file and the `careamics.yaml` attachment are then
    resolved relative to the extracted directory.

    Parameters
    ----------
    model_desc : ModelDescr
        Model description.

    Returns
    -------
    tuple of (pathlib.Path, pathlib.Path)
        Weights and configuration paths.

    Raises
    ------
    ValueError
        If the model description has no PyTorch state dict weights, or if no
        attachment named `careamics.yaml` is found.
    """
    if model_desc.weights.pytorch_state_dict is None:
        raise ValueError("No model weights found in model description.")

    # extract the zip model and return the directory
    model_dir = extract(model_desc.root)

    # a source may be a plain `Path` (which has no `.path` attribute) or a
    # descriptor object exposing its location through `.path`
    weights_source = model_desc.weights.pytorch_state_dict.source
    weights_rel_path = (
        weights_source if isinstance(weights_source, Path) else weights_source.path
    )
    weights_path = model_dir.joinpath(weights_rel_path)

    for file in model_desc.attachments:
        file_path = file.source if isinstance(file.source, Path) else file.source.path
        if file_path is None:
            continue
        file_path = Path(file_path)
        if file_path.name == "careamics.yaml":
            # use the resolved path: `file.source.path` would raise
            # AttributeError when the source is a `Path`
            config_path = model_dir.joinpath(file_path)
            break
    else:
        # loop finished without `break`: no configuration attachment
        raise ValueError("Configuration file not found.")

    return weights_path, config_path