careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,215 @@
1
+ """Module containing convenience function to create `WriteStrategy`."""
2
+
3
+ from typing import Any
4
+
5
+ from careamics.config.support import SupportedData
6
+ from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
7
+
8
+ from .write_strategy import CacheTiles, WriteImage, WriteStrategy
9
+
10
+
11
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: WriteFunc | None = None,
    write_extension: str | None = None,
    write_func_kwargs: dict[str, Any] | None = None,
) -> WriteStrategy:
    """
    Build a `WriteStrategy` from a small set of convenience parameters.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, including "custom".
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        Ignored when a known `write_type` is selected. For a "custom"
        `write_type` a function to save the data must be passed; see the notes
        below for the expected signature.
    write_extension : str, optional
        Ignored when a known `write_type` is selected. For a "custom"
        `write_type` the extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function. An
        empty dict is used when `None`.

    Returns
    -------
    WriteStrategy
        A strategy for writing predictions: `WriteImage` for untiled
        predictions, otherwise the strategy built by
        `_create_tiled_write_strategy`.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
    ```
    write_func(file_path=file_path, img=img, **kwargs)
    ```
    """
    kwargs: dict[str, Any] = {} if write_func_kwargs is None else write_func_kwargs

    if tiled:
        # Tiled output is delegated to the tiled factory (CacheTiles today;
        # a direct-to-zarr tile writer may be added there later).
        return _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=kwargs,
        )

    # Untiled: resolve writer first, then extension (same order as the
    # selectors' validation errors would be raised), and write whole images.
    return WriteImage(
        write_func=select_write_func(write_type=write_type, write_func=write_func),
        write_extension=select_write_extension(
            write_type=write_type, write_extension=write_extension
        ),
        write_func_kwargs=kwargs,
    )
77
+
78
+
79
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: WriteFunc | None,
    write_extension: str | None,
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Create a write strategy for tiled predictions.

    Currently this always returns a `CacheTiles` strategy, which caches tiles
    until a whole image has been predicted. Writing tiles directly to zarr is
    not implemented yet.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, including "custom".
    write_func : WriteFunc, optional
        Ignored when a known `write_type` is selected. For a "custom"
        `write_type` a function to save the data must be passed.
    write_extension : str, optional
        Ignored when a known `write_type` is selected. For a "custom"
        `write_type` the extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing tiled predictions.

    Raises
    ------
    NotImplementedError
        If `write_type="zarr"` is chosen.
    ValueError
        Propagated from `select_write_func` when `write_type` is "custom" and
        no `write_func` is given.
    """
    # TODO: return a `WriteTilesZarr` strategy here once it is implemented.
    if write_type == "zarr":
        raise NotImplementedError("Saving to zarr is not implemented yet.")

    resolved_func = select_write_func(write_type=write_type, write_func=write_func)
    resolved_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return CacheTiles(
        write_func=resolved_func,
        write_extension=resolved_extension,
        write_func_kwargs=write_func_kwargs,
    )
130
+
131
+
132
def select_write_func(
    write_type: SupportedWriteType, write_func: WriteFunc | None = None
) -> WriteFunc:
    """
    Pick the function used to write images.

    For a "custom" `write_type` the user-provided `write_func` is returned;
    for any other `write_type` the built-in writer from `get_write_func` is
    used.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, including "custom".
    write_func : WriteFunc, optional
        Ignored when a known `write_type` is selected. For a "custom"
        `write_type` a function to save the data must be passed; see the notes
        below for the expected signature.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```
    """
    if write_type == SupportedData.CUSTOM:
        # Custom data: the caller is responsible for supplying the writer.
        if write_func is None:
            raise ValueError(
                "A save function must be provided for custom data types."
                # TODO: link to how save functions should be implemented
            )
        return write_func

    # Known data type: look up the built-in writer.
    return get_write_func(write_type)
177
+
178
+
179
def select_write_extension(
    write_type: SupportedWriteType, write_extension: str | None = None
) -> str:
    """
    Return an extension to add to file paths.

    If `write_type` is "custom" then `write_extension` is returned, otherwise
    the known extension for `write_type` is selected via
    `SupportedData.get_extension`.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given.
    """
    write_type_: SupportedData = SupportedData(write_type)  # new variable for mypy

    if write_type_ == SupportedData.CUSTOM:
        if write_extension is None:
            raise ValueError("A save extension must be provided for custom data types.")
        return write_extension

    # kind of a weird pattern -> reason to move get_extension from SupportedData
    return write_type_.get_extension(write_type_)
@@ -0,0 +1,90 @@
1
+ """Progressbar callback."""
2
+
3
+ import sys
4
+ from typing import Union
5
+
6
+ from pytorch_lightning import LightningModule, Trainer
7
+ from pytorch_lightning.callbacks import TQDMProgressBar
8
+ from tqdm.auto import tqdm
9
+
10
+
11
class ProgressBarCallback(TQDMProgressBar):
    """Progress bar for training, validation and testing steps.

    Customises the tqdm bars created by `TQDMProgressBar` and the metrics
    displayed next to them.
    """

    def init_train_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for training.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        bar = tqdm(
            desc="Training",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            # smoothing=0 makes tqdm report the average rate over the whole run
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for validation.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        # The main (training) progress bar only exists when validation runs
        # inside `trainer.fit()`, not in `trainer.validate()`. In Lightning 2.x
        # the `train_progress_bar` property raises `TypeError` when no training
        # bar has been created, so it cannot be used for this check; query the
        # trainer state instead (this mirrors Lightning's own implementation).
        has_main_bar = self.trainer.state.fn != "validate"
        bar = tqdm(
            desc="Validating",
            # sit one row below the training bar when it exists
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for testing.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        bar = tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=False,
            ncols=100,
            file=sys.stdout,
        )
        return bar

    def get_metrics(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
        """Override this to customize the metrics displayed in the progress bar.

        Parameters
        ----------
        trainer : Trainer
            The trainer object.
        pl_module : LightningModule
            The LightningModule object, unused.

        Returns
        -------
        dict
            A shallow copy of `trainer.progress_bar_metrics`; nothing else is
            added to the displayed metrics.
        """
        pbar_metrics = trainer.progress_bar_metrics
        return {**pbar_metrics}
@@ -0,0 +1 @@
1
+ """Next-Generation DataModules for Careamics."""
@@ -0,0 +1 @@
1
+ """NG Dataset compatible callbacks for PyTorch Lightning."""
@@ -0,0 +1,29 @@
1
+ """A package for the `PredictionWriterCallback` class and utilities."""
2
+
3
__all__ = [
    # strategies and the callback
    "CachedTiles",
    "PredictionWriterCallback",
    "WriteImage",
    "WriteStrategy",
    "WriteTilesZarr",
    # helper functions
    "create_write_file_path",
    "create_write_strategy",
    "decollate_image_region_data",
    "select_write_extension",
    "select_write_func",
]
15
+
16
+ from .cached_tiles_strategy import CachedTiles
17
+ from .file_path_utils import create_write_file_path
18
+ from .prediction_writer_callback import (
19
+ PredictionWriterCallback,
20
+ decollate_image_region_data,
21
+ )
22
+ from .write_image_strategy import WriteImage
23
+ from .write_strategy import WriteStrategy
24
+ from .write_strategy_factory import (
25
+ create_write_strategy,
26
+ select_write_extension,
27
+ select_write_func,
28
+ )
29
+ from .write_tiles_zarr_strategy import WriteTilesZarr
@@ -0,0 +1,164 @@
1
+ """A writing strategy that caches tiles until a whole image is predicted."""
2
+
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ from careamics.dataset_ng.dataset import ImageRegionData
8
+ from careamics.file_io import WriteFunc
9
+ from careamics.lightning.dataset_ng.prediction import (
10
+ stitch_single_prediction,
11
+ )
12
+
13
+ from .file_path_utils import create_write_file_path
14
+ from .write_strategy import WriteStrategy
15
+
16
+
17
class CachedTiles(WriteStrategy):
    """
    A write strategy that will cache tiles.

    Tiles are cached, grouped by the `data_idx` entry of their `region_spec`,
    until every tile of an image has been predicted on. Then the stitched
    prediction is saved and the tiles are removed from the cache.

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    tile_cache : dict of {int: list of ImageRegionData}
        Tiles cached for stitching, keyed by the tiles' `data_idx`. An entry is
        removed once its full image has been stitched and written.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        A write strategy that will cache tiles.

        Tiles are cached until a whole image is predicted on. Then the stitched
        prediction is saved.

        Parameters
        ----------
        write_func : WriteFunc
            Function used to save predictions.
        write_extension : str
            Extension added to prediction file paths.
        write_func_kwargs : dict of {str: Any}
            Extra kwargs to pass to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

        # where tiles will be cached until a whole image has been predicted
        self.tile_cache: dict[int, list[ImageRegionData]] = defaultdict(list)

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Cache tiles until the last tile is predicted, then save the stitched image.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        assert predictions is not None

        # cache tiles, grouped by the image they belong to
        for tile in predictions:
            self.tile_cache[tile.region_spec["data_idx"]].append(tile)

        # write (and evict) every image whose tiles are now all cached
        self._write_images(dirpath)

    def _get_full_images(self) -> list[int]:
        """
        Get data indices of full images contained in the cache.

        Returns
        -------
        list of int
            Data indices of full images contained in the cache.

        Raises
        ------
        ValueError
            If more tiles are cached for an image than its expected total.
        """
        full_images: list[int] = []
        for data_idx, tiles in self.tile_cache.items():
            # expected total read from the first cached tile (presumably the
            # same for all tiles of one image)
            exp_n_tiles = tiles[0].region_spec["total_tiles"]

            if len(tiles) == exp_n_tiles:
                full_images.append(data_idx)
            elif len(tiles) > exp_n_tiles:
                raise ValueError(
                    f"More tiles cached for data_idx {data_idx} than expected. "
                    f"Expected {exp_n_tiles}, found "
                    f"{len(tiles)}."
                )

        return full_images

    def _stitch_and_write_single(
        self, dirpath: Path, tiles: list[ImageRegionData]
    ) -> None:
        """
        Stitch and write a single image from tiles.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        tiles : list[ImageRegionData]
            Tiles to stitch and write.
        """
        # stitch prediction
        prediction_image = stitch_single_prediction(tiles)

        # write prediction, named after the source of the first tile
        source: Path = Path(tiles[0].source)
        file_path = create_write_file_path(
            dirpath=dirpath,
            file_path=source,
            write_extension=self.write_extension,
        )
        self.write_func(
            file_path=file_path, img=prediction_image, **self.write_func_kwargs
        )

    def _write_images(self, dirpath: Path) -> None:
        """
        Write full images from cached tiles and remove them from the cache.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        """
        # indices are collected first so the cache is not mutated while iterated
        full_images = self._get_full_images()
        for data_idx in full_images:
            tiles = self.tile_cache.pop(data_idx)
            self._stitch_and_write_single(dirpath, tiles)
@@ -0,0 +1,33 @@
1
+ """Module containing file path utilities for `WriteStrategy` to use."""
2
+
3
+ from pathlib import Path
4
+
5
+
6
def create_write_file_path(
    dirpath: Path, file_path: Path, write_extension: str
) -> Path:
    """
    Create the file name for the output file.

    Takes the original file path, changes the directory to `dirpath` and
    replaces the final extension with `write_extension`. Only the last suffix
    is replaced, so other dots in the name are kept (e.g. ``img.2023.tif``
    becomes ``img.2023.tiff``) and distinct inputs do not collide.

    Parameters
    ----------
    dirpath : pathlib.Path
        The output directory to write file to.
    file_path : pathlib.Path
        The original file path.
    write_extension : str
        The extension that output files should have, including the leading dot
        (e.g. ``".tiff"``).

    Returns
    -------
    Path
        The output file path.

    Raises
    ------
    ValueError
        Raised by `pathlib.PurePath.with_suffix` if `write_extension` is not a
        valid suffix (e.g. missing the leading dot) or `file_path` has an empty
        name.
    """
    # as a guard against str input
    file_path = Path(file_path)
    dirpath = Path(dirpath)

    # replace only the last suffix; `Path(stem).with_suffix` would also drop
    # any dotted part of the stem
    file_name = file_path.with_suffix(write_extension).name
    return dirpath / file_name
+ return file_path