careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,736 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Creating a custom `ImageStack`\n",
8
+ "\n",
9
+ "You might want to write a custom `ImageStack` class if you have data stored in a format\n",
10
+ "that is chunked or capable of sub-file access, i.e. you want to be able to extract \n",
11
+ "patches during the training loop without having to load all the data into RAM at once. \n",
12
+ "The image stack has to follow the python `Protocol` defined in [patch_extractor/image_stack/image_stack_protocol.py](patch_extractor/image_stack/image_stack_protocol.py).\n",
13
+ "\n",
14
+ "To use a custom `ImageStack` with the `CAREamicsDataset` we will also have to write an\n",
15
+ "image stack loader function, that has a protocol defined in [src/careamics/dataset_ng/patch_extractor/image_stack_loader.py](patch_extractor/image_stack_loader.py). It is a callable with the function signature:\n",
16
+ "\n",
17
+ "```python\n",
18
+ "# example signature\n",
19
+ "def custom_image_stack_loader(\n",
20
+ " source: Any, axes: str, *args: Any, **kwargs: Any\n",
21
+ ") -> Sequence[ImageStack]: ...\n",
22
+ "```\n",
23
+ "\n",
24
+ "In this demo, we will create a custom image stack and image stack loader for data saved\n",
25
+ "in a hdf5 file."
26
+ ]
27
+ },
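+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before diving into the HDF5 case, here is a minimal sketch of a protocol-conforming\n",
+ "class (illustrative only: the class name and the in-memory array are placeholders,\n",
+ "only the attribute and method names come from the protocol):\n",
+ "\n",
+ "```python\n",
+ "# any class with these attributes and this method satisfies the protocol;\n",
+ "# no inheritance is needed - protocols are matched structurally\n",
+ "class MyImageStack:\n",
+ "    def __init__(self, data):  # data: an in-memory array with SCYX axes\n",
+ "        self._data = data\n",
+ "        self.data_shape = data.shape  # shape in SC(Z)YX order\n",
+ "\n",
+ "    @property\n",
+ "    def data_dtype(self):\n",
+ "        return self._data.dtype\n",
+ "\n",
+ "    @property\n",
+ "    def source(self):\n",
+ "        return Path(\"in_memory\")\n",
+ "\n",
+ "    def extract_patch(self, sample_idx, coords, patch_size):\n",
+ "        # return a (C, Y, X) patch for one sample\n",
+ "        return self._data[\n",
+ "            sample_idx,\n",
+ "            :,\n",
+ "            coords[0] : coords[0] + patch_size[0],\n",
+ "            coords[1] : coords[1] + patch_size[1],\n",
+ "        ]\n",
+ "```"
+ ]
+ },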
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "from collections.abc import Sequence\n",
35
+ "from pathlib import Path\n",
36
+ "from typing import Union\n",
37
+ "\n",
38
+ "import h5py\n",
39
+ "import matplotlib.pyplot as plt\n",
40
+ "import numpy as np\n",
41
+ "import tifffile\n",
42
+ "from careamics_portfolio import PortfolioManager\n",
43
+ "from numpy.typing import DTypeLike, NDArray\n",
44
+ "\n",
45
+ "from careamics.config import create_care_configuration\n",
46
+ "from careamics.dataset_ng.dataset import Mode\n",
47
+ "from careamics.dataset_ng.factory import create_dataset"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "metadata": {},
53
+ "source": [
54
+ "## Downloading and re-saving data\n",
55
+ "\n",
56
+ "We will resave some data as HDF5 for the purpose of this demo.\n",
57
+ "\n",
58
+ "First we download some data that is available using `careamics_portfolio`."
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "# instantiate data portfolio manager and download the data\n",
68
+ "data_path = Path(\"./data\")\n",
69
+ "\n",
70
+ "portfolio = PortfolioManager()\n",
71
+ "download = portfolio.denoising.CARE_U2OS.download(data_path)"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "root_path = data_path / \"denoising-CARE_U2OS.unzip\" / \"data\" / \"U2OS\"\n",
81
+ "train_path = root_path / \"train\" / \"low\"\n",
82
+ "target_path = root_path / \"train\" / \"GT\"\n",
83
+ "test_path = root_path / \"test\" / \"low\"\n",
84
+ "test_target_path = root_path / \"test\" / \"GT\""
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "# checking the train input and target files we have\n",
94
+ "print(list(train_path.glob(\"*.tif\")))\n",
95
+ "print(list(target_path.glob(\"*.tif\")))"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "markdown",
100
+ "metadata": {},
101
+ "source": [
102
+ "### Save as HDF5\n",
103
+ "\n",
104
+ "We will save all the images in a HDF5 file, the input images under a \"train\" path and \n",
105
+ "target images under a \"target\" path, and all the images will have their original file \n",
106
+ "name."
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "hdf5_file_path = data_path / \"CARE_U2OS-train.h5\"\n",
116
+ "\n",
117
+ "if not hdf5_file_path.is_file():\n",
118
+ " with h5py.File(name=hdf5_file_path, mode=\"w\") as file:\n",
119
+ " train_group = file.create_group(\"train_input\")\n",
120
+ " target_group = file.create_group(\"train_target\")\n",
121
+ " test_group = file.create_group(\"test_input\")\n",
122
+ " test_target_group = file.create_group(\"test_target\")\n",
123
+ " for path in train_path.glob(\"*.tif\"):\n",
124
+ " image = tifffile.imread(path)\n",
125
+ " train_group.create_dataset(name=path.stem, data=image)\n",
126
+ " for path in target_path.glob(\"*.tif\"):\n",
127
+ " image = tifffile.imread(path)\n",
128
+ " target_group.create_dataset(name=path.stem, data=image)\n",
129
+ " for path in test_path.glob(\"*.tif\"):\n",
130
+ " image = tifffile.imread(path)\n",
131
+ " test_group.create_dataset(name=path.stem, data=image)\n",
132
+ " for path in test_target_path.glob(\"*.tif\"):\n",
133
+ " image = tifffile.imread(path)\n",
134
+ " test_target_group.create_dataset(name=path.stem, data=image)"
135
+ ]
136
+ },
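+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As an optional sanity check, the layout of the resulting file can be listed with\n",
+ "`h5py`'s `visititems` (`print_node` is just a throwaway helper name):\n",
+ "\n",
+ "```python\n",
+ "# walk the HDF5 hierarchy and print every dataset with its shape and dtype\n",
+ "with h5py.File(hdf5_file_path, mode=\"r\") as file:\n",
+ "    def print_node(name, node):\n",
+ "        if isinstance(node, h5py.Dataset):\n",
+ "            print(f\"{name}: shape={node.shape}, dtype={node.dtype}\")\n",
+ "\n",
+ "    file.visititems(print_node)\n",
+ "```"
+ ]
+ },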
137
+ {
138
+ "cell_type": "markdown",
139
+ "metadata": {},
140
+ "source": [
141
+ "# Defining the image stack\n",
142
+ "\n",
143
+ "An ImageStack must have the attributes: `data_shape`, `data_dtype` and `source` and the\n",
144
+ "method `extract_patch`.\n",
145
+ "\n",
146
+ "The `data_shape` attribute should be shape the data would have once reshaped to match the axes \n",
147
+ "`SC(Z)YX`.\n",
148
+ "\n",
149
+ "The `data_dtype` attribute is the data type of the underlying array.\n",
150
+ "\n",
151
+ "The `source` attribute should have the type `Path`, it will be returned alongside the patches by the\n",
152
+ "`CAREamicsDataset` and can be used as a way to identify where the data came from. In the\n",
153
+ "future it may be used as a way to automatically save predictions to disk.\n",
154
+ "\n",
155
+ "The `extract_patch` method needs to return a patch for a given `sample_index`, `coords` \n",
156
+ "and `patch_size` that has the axes `SC(Z)YX`. So, for our HDF5 case the patches need to \n",
157
+ "be reshaped when the `extract_patch_method` is called."
158
+ ]
159
+ },
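+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For intuition, here is roughly how the two reshaping helpers imported in the next\n",
+ "cell behave (a sketch; the expected shapes assume the `SC(Z)YX` convention described\n",
+ "above):\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "image = np.zeros((512, 256))  # a single 2D image with axes \"YX\"\n",
+ "\n",
+ "# both helpers prepend singleton S and C dimensions to reach SC(Z)YX\n",
+ "print(_reshaped_array_shape(\"YX\", image.shape))  # expected: (1, 1, 512, 256)\n",
+ "print(reshape_array(image, \"YX\").shape)  # expected: (1, 1, 512, 256)\n",
+ "```"
+ ]
+ },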
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": null,
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "from careamics.dataset_ng.patch_extractor.image_stack.zarr_image_stack import (\n",
167
+ " _reshaped_array_shape,\n",
168
+ ")\n",
169
+ "\n",
170
+ "from careamics.dataset.dataset_utils import reshape_array\n",
171
+ "\n",
172
+ "\n",
173
+ "class HDF5ImageStack:\n",
174
+ "\n",
175
+ " def __init__(self, image_data: h5py.Dataset, axes: str):\n",
176
+ " self._image_data = image_data\n",
177
+ " self._original_axes = axes\n",
178
+ " self._original_data_shape = image_data.shape\n",
179
+ " self.data_shape = _reshaped_array_shape(\n",
180
+ " self._original_axes, self._image_data.shape\n",
181
+ " )\n",
182
+ "\n",
183
+ " @property\n",
184
+ " def data_dtype(self) -> DTypeLike:\n",
185
+ " return self._image_data.dtype\n",
186
+ "\n",
187
+ " @property\n",
188
+ " def source(self) -> Path:\n",
189
+ " return Path(self._image_data.file.filename + str(self._image_data.name))\n",
190
+ "\n",
191
+ " # this method is almost an exact copy of the ZarrImageStack.extract patch\n",
192
+ " def extract_patch(\n",
193
+ " self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]\n",
194
+ " ) -> NDArray:\n",
195
+ " # original axes assumed to be any subset of STCZYX (containing YX), in any order\n",
196
+ " # arguments must be transformed to index data in original axes order\n",
197
+ " # to do this: loop through original axes and append correct index/slice\n",
198
+ " # for each case: STCZYX\n",
199
+ " # Note: if any axis is not present in original_axes it is skipped.\n",
200
+ "\n",
201
+ " # guard for no S and T in original axes\n",
202
+ " if (\"S\" not in self._original_axes) and (\"T\" not in self._original_axes):\n",
203
+ " if sample_idx not in [0, -1]:\n",
204
+ " raise IndexError(\n",
205
+ " f\"Sample index {sample_idx} out of bounds for S axes with size \"\n",
206
+ " f\"{self.data_shape[0]}\"\n",
207
+ " )\n",
208
+ "\n",
209
+ " patch_slice: list[Union[int, slice]] = []\n",
210
+ " for d in self._original_axes:\n",
211
+ " if d == \"S\":\n",
212
+ " patch_slice.append(self._get_S_index(sample_idx))\n",
213
+ " elif d == \"T\":\n",
214
+ " patch_slice.append(self._get_T_index(sample_idx))\n",
215
+ " elif d == \"C\":\n",
216
+ " patch_slice.append(slice(None, None))\n",
217
+ " elif d == \"Z\":\n",
218
+ " patch_slice.append(slice(coords[0], coords[0] + patch_size[0]))\n",
219
+ " elif d == \"Y\":\n",
220
+ " y_idx = 0 if \"Z\" not in self._original_axes else 1\n",
221
+ " patch_slice.append(\n",
222
+ " slice(coords[y_idx], coords[y_idx] + patch_size[y_idx])\n",
223
+ " )\n",
224
+ " elif d == \"X\":\n",
225
+ " x_idx = 1 if \"Z\" not in self._original_axes else 2\n",
226
+ " patch_slice.append(\n",
227
+ " slice(coords[x_idx], coords[x_idx] + patch_size[x_idx])\n",
228
+ " )\n",
229
+ " else:\n",
230
+ " raise ValueError(f\"Unrecognised axis '{d}', axes should be in STCZYX.\")\n",
231
+ "\n",
232
+ " patch = self._image_data[tuple(patch_slice)]\n",
233
+ " patch_axes = self._original_axes.replace(\"S\", \"\").replace(\"T\", \"\")\n",
234
+ " return reshape_array(patch, patch_axes)[0] # remove first sample dim\n",
235
+ "\n",
236
+ " def _get_T_index(self, sample_idx: int) -> int:\n",
237
+ " \"\"\"Get T index given `sample_idx`.\"\"\"\n",
238
+ " if \"T\" not in self._original_axes:\n",
239
+ " raise ValueError(\"No 'T' axis specified in original data axes.\")\n",
240
+ " axis_idx = self._original_axes.index(\"T\")\n",
241
+ " dim = self._original_data_shape[axis_idx]\n",
242
+ "\n",
243
+ " # new S' = S*T\n",
244
+ " # T_idx = S_idx' // T_size\n",
245
+ " # S_idx = S_idx' % T_size\n",
246
+ " # - floor divide finds the row\n",
247
+ " # - modulus finds how far along the row i.e. the column\n",
248
+ " return sample_idx % dim\n",
249
+ "\n",
250
+ " def _get_S_index(self, sample_idx: int) -> int:\n",
251
+ " \"\"\"Get S index given `sample_idx`.\"\"\"\n",
252
+ " if \"S\" not in self._original_axes:\n",
253
+ " raise ValueError(\"No 'S' axis specified in original data axes.\")\n",
254
+ " if \"T\" in self._original_axes:\n",
255
+ " T_axis_idx = self._original_axes.index(\"T\")\n",
256
+ " T_dim = self._original_data_shape[T_axis_idx]\n",
257
+ "\n",
258
+ " # new S' = S*T\n",
259
+ " # T_idx = S_idx' // T_size\n",
260
+ " # S_idx = S_idx' % T_size\n",
261
+ " # - floor divide finds the row\n",
262
+ " # - modulus finds how far along the row i.e. the column\n",
263
+ " return sample_idx // T_dim\n",
264
+ " else:\n",
265
+ " return sample_idx"
266
+ ]
267
+ },
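+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With the class defined, a quick smoke test on a throwaway in-memory HDF5 file is a\n",
+ "useful sanity check (a sketch: the dataset name and sizes are arbitrary, and the\n",
+ "expected shapes assume the reshaping behaviour described above):\n",
+ "\n",
+ "```python\n",
+ "# backing_store=False keeps the file entirely in memory, nothing is written to disk\n",
+ "with h5py.File(\"sanity.h5\", mode=\"w\", driver=\"core\", backing_store=False) as f:\n",
+ "    dset = f.create_dataset(\"demo\", data=np.random.rand(3, 64, 64))  # axes \"SYX\"\n",
+ "    stack = HDF5ImageStack(dset, axes=\"SYX\")\n",
+ "    print(stack.data_shape)  # expected: (3, 1, 64, 64), i.e. SCYX\n",
+ "    patch = stack.extract_patch(sample_idx=1, coords=(8, 8), patch_size=(16, 16))\n",
+ "    print(patch.shape)  # expected: (1, 16, 16), i.e. CYX\n",
+ "```"
+ ]
+ },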
268
+ {
269
+ "cell_type": "markdown",
270
+ "metadata": {},
271
+ "source": [
272
+ "### Now define the image loader\n",
273
+ "\n",
274
+ "The loader needs to have the first two arguments be `source` and `axes`, then any \n",
275
+ "additional kwargs are allowed. However, note that the additional kwargs have to be \n",
276
+ "shared by both the input and the target when the dataset is initialized.\n"
277
+ ]
278
+ },
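+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Conceptually, the dataset calls the loader once for the inputs and once for the\n",
+ "targets, forwarding the same extra kwargs to both calls (a sketch, not the actual\n",
+ "dataset internals):\n",
+ "\n",
+ "```python\n",
+ "# the same image_stack_loader_kwargs are applied to inputs and targets\n",
+ "input_stacks = image_stack_loader(inputs, axes, **image_stack_loader_kwargs)\n",
+ "target_stacks = image_stack_loader(targets, axes, **image_stack_loader_kwargs)\n",
+ "```"
+ ]
+ },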
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": null,
282
+ "metadata": {},
283
+ "outputs": [],
284
+ "source": [
285
+ "# A image stack loader\n",
286
+ "# both the input and target image stacks must be contained within the same HDF5 file\n",
287
+ "def hdf5_image_stack_loader(\n",
288
+ " source: Sequence[str], axes: str, file: h5py.File\n",
289
+ ") -> Sequence[HDF5ImageStack]:\n",
290
+ " image_stacks: list[HDF5ImageStack] = []\n",
291
+ " for data_path in source:\n",
292
+ " if data_path not in file:\n",
293
+ " raise KeyError(f\"Data does not exist at path '{data_path}'\")\n",
294
+ " image_data = file[data_path]\n",
295
+ " if not isinstance(image_data, h5py.Dataset):\n",
296
+ " raise TypeError(f\"HDF5 node at path '{data_path}' is not a Dataset.\")\n",
297
+ " image_stacks.append(HDF5ImageStack(image_data, axes=axes))\n",
298
+ " return image_stacks"
299
+ ]
300
+ },
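+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The loader can also be called by hand, outside of any dataset, which is convenient\n",
+ "for debugging (a sketch, reusing the HDF5 file written earlier):\n",
+ "\n",
+ "```python\n",
+ "# resolve the first two training inputs into image stacks and inspect them\n",
+ "with h5py.File(hdf5_file_path, mode=\"r\") as file:\n",
+ "    keys = sorted(file[\"train_input\"].keys())\n",
+ "    stacks = hdf5_image_stack_loader(\n",
+ "        [f\"train_input/{k}\" for k in keys[:2]], axes=\"YX\", file=file\n",
+ "    )\n",
+ "    for stack in stacks:\n",
+ "        print(stack.source, stack.data_shape)\n",
+ "```"
+ ]
+ },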
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "# This is an alternative hdf5 image stack loader\n",
308
+ "# The input and target files can be contained in separate hdf5 files\n",
309
+ "# An HDF5Source typed dict has to be defined\n",
310
+ "# this is to allow both the file and the data paths to be combined in a single argument\n",
311
+ "\n",
312
+ "from typing import TypedDict\n",
313
+ "\n",
314
+ "\n",
315
+ "class HDF5Source(TypedDict):\n",
316
+ " file: h5py.File\n",
317
+ " data_path: str\n",
318
+ "\n",
319
+ "\n",
320
+ "def hdf5_image_stack_loader_alt(\n",
321
+ " source: Sequence[HDF5Source], axes: str\n",
322
+ ") -> Sequence[HDF5ImageStack]:\n",
323
+ " image_stacks: list[HDF5ImageStack] = []\n",
324
+ " for image_stack_source in source:\n",
325
+ " data_path = image_stack_source[\"data_path\"]\n",
326
+ " file = image_stack_source[\"file\"]\n",
327
+ " if data_path not in file:\n",
328
+ " raise KeyError(f\"Data does not exist at path '{data_path}'\")\n",
329
+ " image_data = file[data_path]\n",
330
+ " if not isinstance(image_data, h5py.Dataset):\n",
331
+ " raise TypeError(f\"HDF5 node at path '{data_path}' is not a Dataset.\")\n",
332
+ " image_stacks.append(HDF5ImageStack(image_data, axes=axes))\n",
333
+ " return image_stacks"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "markdown",
338
+ "metadata": {},
339
+ "source": [
340
+ "## Now we test it\n",
341
+ "\n",
342
+ "### create a configuration for the data"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": null,
348
+ "metadata": {},
349
+ "outputs": [],
350
+ "source": [
351
+ "train_files = sorted(train_path.glob(\"*.tif\"))\n",
352
+ "train_target_files = sorted(target_path.glob(\"*.tif\"))\n",
353
+ "\n",
354
+ "config = create_care_configuration(\n",
355
+ " experiment_name=\"care_U20S\",\n",
356
+ " data_type=\"custom\",\n",
357
+ " axes=\"YX\",\n",
358
+ " patch_size=[128, 128],\n",
359
+ " batch_size=32,\n",
360
+ " num_epochs=50,\n",
361
+ ")"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": null,
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "hdf5_file = h5py.File(hdf5_file_path, mode=\"r\")\n",
371
+ "\n",
372
+ "inputs = sorted([f\"train_input/{key}\" for key in hdf5_file[\"train_input\"].keys()])\n",
373
+ "targets = sorted([f\"train_target/{key}\" for key in hdf5_file[\"train_target\"].keys()])\n",
374
+ "test_inputs = sorted([f\"test_input/{key}\" for key in hdf5_file[\"test_input\"].keys()])\n",
375
+ "test_targets = sorted([f\"test_target/{key}\" for key in hdf5_file[\"test_target\"].keys()])\n",
376
+ "\n",
377
+ "dataset = create_dataset(\n",
378
+ " config=config.data_config,\n",
379
+ " mode=Mode.TRAINING,\n",
380
+ " inputs=inputs,\n",
381
+ " targets=targets,\n",
382
+ " in_memory=False,\n",
383
+ " image_stack_loader=hdf5_image_stack_loader,\n",
384
+ " image_stack_loader_kwargs={\"file\": hdf5_file},\n",
385
+ ")"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "markdown",
390
+ "metadata": {},
391
+ "source": [
392
+ "### Index the dataset and display the result"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": null,
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "fig, axes = plt.subplots(1, 2)\n",
402
+ "train_input, target = dataset[0]\n",
403
+ "axes[0].imshow(train_input.data[0])\n",
404
+ "axes[0].set_title(\"Input\")\n",
405
+ "axes[1].imshow(target.data[0])\n",
406
+ "axes[1].set_title(\"Target\")"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": null,
412
+ "metadata": {},
413
+ "outputs": [],
414
+ "source": [
415
+ "# input and target are ImageRegionData objects\n",
416
+ "train_input, target"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "markdown",
421
+ "metadata": {},
422
+ "source": [
423
+ "### Test the alternative image stack loader"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "hdf5_file = h5py.File(hdf5_file_path, mode=\"r\")\n",
433
+ "\n",
434
+ "data_keys = sorted(hdf5_file[\"train_input\"].keys())\n",
435
+ "\n",
436
+ "# for the alternative image stack loader we have to construct a list of dicts\n",
437
+ "# because we defined the source type to be a HDF5Source typed dict\n",
438
+ "inputs: list[HDF5Source] = [\n",
439
+ " {\"data_path\": f\"train_input/{key}\", \"file\": hdf5_file} for key in data_keys\n",
440
+ "]\n",
441
+ "targets: list[HDF5Source] = [\n",
442
+ " {\"data_path\": f\"train_target/{key}\", \"file\": hdf5_file} for key in data_keys\n",
443
+ "]\n",
444
+ "\n",
445
+ "dataset = create_dataset(\n",
446
+ " config=config.data_config,\n",
447
+ " mode=Mode.TRAINING,\n",
448
+ " inputs=inputs,\n",
449
+ " targets=targets,\n",
450
+ " in_memory=False,\n",
451
+ " image_stack_loader=hdf5_image_stack_loader_alt,\n",
452
+ " # now we don't have any additional kwargs\n",
453
+ ")"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "metadata": {},
460
+ "outputs": [],
461
+ "source": [
462
+ "# display the first item\n",
463
+ "# note this will be a different patch because of the random patching\n",
464
+ "fig, axes = plt.subplots(1, 2)\n",
465
+ "train_input, target = dataset[0]\n",
466
+ "axes[0].imshow(train_input.data[0])\n",
467
+ "axes[0].set_title(\"Input\")\n",
468
+ "axes[1].imshow(target.data[0])\n",
469
+ "axes[1].set_title(\"Target\")"
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "markdown",
474
+ "metadata": {},
475
+ "source": [
476
+ "### Now let's run N2V training pipeline and see how it performs"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "markdown",
481
+ "metadata": {},
482
+ "source": [
483
+ "#### Creating the lightning data module for training"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": null,
489
+ "metadata": {},
490
+ "outputs": [],
491
+ "source": [
492
+ "from careamics.config.inference_model import InferenceConfig\n",
493
+ "\n",
494
+ "from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
495
+ "\n",
496
+ "hdf5_file = h5py.File(hdf5_file_path, mode=\"r\")\n",
497
+ "\n",
498
+ "train_data_keys = sorted(hdf5_file[\"train_input\"].keys())\n",
499
+ "\n",
500
+ "inputs: list[HDF5Source] = [\n",
501
+ " {\"data_path\": f\"train_input/{key}\", \"file\": hdf5_file} for key in train_data_keys\n",
502
+ "]\n",
503
+ "targets: list[HDF5Source] = [\n",
504
+ " {\"data_path\": f\"train_target/{key}\", \"file\": hdf5_file} for key in train_data_keys\n",
505
+ "]\n",
506
+ "\n",
507
+ "test_data_keys = sorted(hdf5_file[\"test_input\"].keys())\n",
508
+ "test_inputs: list[HDF5Source] = [\n",
509
+ " {\"data_path\": f\"test_input/{key}\", \"file\": hdf5_file} for key in test_data_keys\n",
510
+ "]\n",
511
+ "test_targets: list[HDF5Source] = [\n",
512
+ " {\"data_path\": f\"test_target/{key}\", \"file\": hdf5_file} for key in test_data_keys\n",
513
+ "]\n",
514
+ "config = create_care_configuration(\n",
515
+ " experiment_name=\"care_U20S\",\n",
516
+ " data_type=\"custom\",\n",
517
+ " axes=\"YX\",\n",
518
+ " patch_size=[128, 128],\n",
519
+ " batch_size=32,\n",
520
+ " num_epochs=50,\n",
521
+ ")\n",
522
+ "train_data_module = CareamicsDataModule(\n",
523
+ " data_config=config.data_config,\n",
524
+ " train_data=inputs,\n",
525
+ " train_data_target=targets,\n",
526
+ " val_data=inputs,\n",
527
+ " val_data_target=targets,\n",
528
+ " image_stack_loader=hdf5_image_stack_loader_alt,\n",
529
+ ")"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "markdown",
534
+ "metadata": {},
535
+ "source": [
536
+ "#### Creating the model and the trainer"
537
+ ]
538
+ },
539
+ {
540
+ "cell_type": "code",
541
+ "execution_count": null,
542
+ "metadata": {},
543
+ "outputs": [],
544
+ "source": [
545
+ "from pytorch_lightning import Trainer\n",
546
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
547
+ "\n",
548
+ "from careamics.lightning.dataset_ng.lightning_modules import CAREModule\n",
549
+ "\n",
550
+ "root = Path(\"care_stack_loader\")\n",
551
+ "\n",
552
+ "# TODO: replace with N2V!!!\n",
553
+ "model = CAREModule(config.algorithm_config)\n",
554
+ "\n",
555
+ "callbacks = [\n",
556
+ " ModelCheckpoint(\n",
557
+ " dirpath=root / \"checkpoints\",\n",
558
+ " filename=\"care_baseline\",\n",
559
+ " save_last=True,\n",
560
+ " monitor=\"val_loss\",\n",
561
+ " mode=\"min\",\n",
562
+ " )\n",
563
+ "]\n",
564
+ "\n",
565
+ "trainer = Trainer(max_epochs=50, default_root_dir=root, callbacks=callbacks)"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "markdown",
570
+ "metadata": {},
571
+ "source": [
572
+ "#### Training the model"
573
+ ]
574
+ },
575
+ {
576
+ "cell_type": "code",
577
+ "execution_count": null,
578
+ "metadata": {},
579
+ "outputs": [],
580
+ "source": [
581
+ "trainer.fit(model, datamodule=train_data_module)"
582
+ ]
583
+ },
584
+ {
585
+ "cell_type": "markdown",
586
+ "metadata": {},
587
+ "source": [
588
+ "#### Creating the inference data module"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "code",
593
+ "execution_count": null,
594
+ "metadata": {},
595
+ "outputs": [],
596
+ "source": [
597
+ "\n",
598
+ "inference_config = InferenceConfig(\n",
599
+ " model_config=config,\n",
600
+ " data_type=\"custom\",\n",
601
+ " tile_size=(128, 128),\n",
602
+ " tile_overlap=(32, 32),\n",
603
+ " axes=\"YX\",\n",
604
+ " batch_size=1,\n",
605
+ " image_means=train_data_module.train_dataset.input_stats.means,\n",
606
+ " image_stds=train_data_module.train_dataset.input_stats.stds,\n",
607
+ ")\n",
608
+ "\n",
609
+ "inf_data_module = CareamicsDataModule(\n",
610
+ " data_config=inference_config,\n",
611
+ " pred_data=test_inputs,\n",
612
+ " image_stack_loader=hdf5_image_stack_loader_alt,\n",
613
+ ")"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "markdown",
618
+ "metadata": {},
619
+ "source": [
620
+ "#### Running the prediction on the test set"
621
+ ]
622
+ },
623
+ {
624
+ "cell_type": "code",
625
+ "execution_count": null,
626
+ "metadata": {},
627
+ "outputs": [],
628
+ "source": [
629
+ "from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
630
+ "from careamics.prediction_utils import convert_outputs\n",
631
+ "\n",
632
+ "predictions = trainer.predict(model, datamodule=inf_data_module)\n",
633
+ "tile_infos = imageregions_to_tileinfos(predictions)\n",
634
+ "prediction = convert_outputs(tile_infos, tiled=True)"
635
+ ]
636
+ },
637
+ {
638
+ "cell_type": "markdown",
639
+ "metadata": {},
640
+ "source": [
641
+ "#### Displaying the predictions"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "metadata": {},
648
+ "outputs": [],
649
+ "source": [
650
+ "from careamics.utils.metrics import psnr, scale_invariant_psnr\n",
651
+ "\n",
652
+ "# Show two images\n",
653
+ "noises = [tifffile.imread(f) for f in sorted(test_path.glob(\"*.tif\"))]\n",
654
+ "gts = [tifffile.imread(f) for f in sorted(test_target_path.glob(\"*.tif\"))]\n",
655
+ "\n",
656
+ "fig, ax = plt.subplots(3, 3, figsize=(7, 7))\n",
657
+ "fig.tight_layout()\n",
658
+ "\n",
659
+ "for i in range(3):\n",
660
+ " pred_image = prediction[i].squeeze()\n",
661
+ " psnr_noisy = psnr(\n",
662
+ " gts[i],\n",
663
+ " noises[i],\n",
664
+ " data_range=gts[i].max() - gts[i].min(),\n",
665
+ " )\n",
666
+ " psnr_result = psnr(\n",
667
+ " gts[i],\n",
668
+ " pred_image,\n",
669
+ " data_range=gts[i].max() - gts[i].min(),\n",
670
+ " )\n",
671
+ "\n",
672
+ " scale_invariant_psnr_result = scale_invariant_psnr(gts[i], pred_image)\n",
673
+ "\n",
674
+ " ax[i, 0].imshow(noises[i], cmap=\"gray\")\n",
675
+ " ax[i, 0].title.set_text(f\"Noisy\\nPSNR: {psnr_noisy:.2f}\")\n",
676
+ "\n",
677
+ " ax[i, 1].imshow(pred_image, cmap=\"gray\")\n",
678
+ " ax[i, 1].title.set_text(\n",
679
+ " f\"Prediction\\nPSNR: {psnr_result:.2f}\\n\"\n",
680
+ " f\"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}\"\n",
681
+ " )\n",
682
+ "\n",
683
+ " ax[i, 2].imshow(gts[i], cmap=\"gray\")\n",
684
+ " ax[i, 2].title.set_text(\"Ground-truth\")"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "markdown",
689
+ "metadata": {},
690
+ "source": [
691
+ "#### Calculating the metrics on the test set"
692
+ ]
693
+ },
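+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, a scale-invariant PSNR rescales the prediction onto the ground truth\n",
+ "with a least-squares fit before computing PSNR, so the score is insensitive to\n",
+ "intensity scaling. A rough sketch of the idea (not the exact `careamics`\n",
+ "implementation):\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "def scale_invariant_psnr_sketch(gt, pred):\n",
+ "    # zero-mean both images, then fit the scalar alpha minimising\n",
+ "    # ||gt - alpha * pred||^2 before computing a standard PSNR\n",
+ "    gt = gt.astype(np.float64) - gt.mean()\n",
+ "    pred = pred.astype(np.float64) - pred.mean()\n",
+ "    alpha = (gt * pred).sum() / (pred * pred).sum()\n",
+ "    mse = np.mean((gt - alpha * pred) ** 2)\n",
+ "    data_range = gt.max() - gt.min()\n",
+ "    return 10 * np.log10(data_range**2 / mse)\n",
+ "```"
+ ]
+ },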
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": null,
697
+ "metadata": {},
698
+ "outputs": [],
699
+ "source": [
700
+ "psnrs = np.zeros((len(prediction), 1))\n",
701
+ "scale_invariant_psnrs = np.zeros((len(prediction), 1))\n",
702
+ "\n",
703
+ "for i, (pred, gt) in enumerate(zip(prediction, gts, strict=False)):\n",
704
+ " psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
705
+ " scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
706
+ "\n",
707
+ "print(f\"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}\")\n",
708
+ "print(\n",
709
+ " f\"Scale invariant PSNR: \"\n",
710
+ " f\"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}\"\n",
711
+ ")"
712
+ ]
713
+ }
714
+ ],
715
+ "metadata": {
716
+ "kernelspec": {
717
+ "display_name": "Python 3",
718
+ "language": "python",
719
+ "name": "python3"
720
+ },
721
+ "language_info": {
722
+ "codemirror_mode": {
723
+ "name": "ipython",
724
+ "version": 3
725
+ },
726
+ "file_extension": ".py",
727
+ "mimetype": "text/x-python",
728
+ "name": "python",
729
+ "nbconvert_exporter": "python",
730
+ "pygments_lexer": "ipython3",
731
+ "version": "3.9.21"
732
+ }
733
+ },
734
+ "nbformat": 4,
735
+ "nbformat_minor": 2
736
+ }