careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,234 @@
1
+ """Module containing `PredictionWriterCallback` class."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Sequence
6
+ from pathlib import Path
7
+ from typing import Any, Union
8
+
9
+ from pytorch_lightning import LightningModule, Trainer
10
+ from pytorch_lightning.callbacks import BasePredictionWriter
11
+ from torch.utils.data import DataLoader
12
+
13
+ from careamics.dataset import (
14
+ IterablePredDataset,
15
+ IterableTiledPredDataset,
16
+ )
17
+ from careamics.file_io import SupportedWriteType, WriteFunc
18
+ from careamics.utils import get_logger
19
+
20
+ from .write_strategy import WriteStrategy
21
+ from .write_strategy_factory import create_write_strategy
22
+
23
logger = get_logger(__name__)

# Prediction dataset types accepted by `PredictionWriterCallback`. Output file
# names are derived from the files the dataset was created from, so in-memory
# prediction datasets are not part of this union.
ValidPredDatasets = Union[IterablePredDataset, IterableTiledPredDataset]
26
+
27
+
28
class PredictionWriterCallback(BasePredictionWriter):
    """
    A PyTorch Lightning callback to save predictions.

    Predictions are written at the end of every prediction batch; how they are
    written is delegated to the `write_strategy`.

    Parameters
    ----------
    write_strategy : WriteStrategy
        A strategy for writing predictions.
    dirpath : pathlib.Path or str, default="predictions"
        The path to the directory where prediction outputs will be saved. If
        `dirpath` is not absolute it is assumed to be relative to current working
        directory.

    Attributes
    ----------
    write_strategy : WriteStrategy
        A strategy for writing predictions.
    dirpath : pathlib.Path
        The absolute path to the directory where prediction outputs will be saved.
        A relative `dirpath` given at initialization is resolved against the
        current working directory.
    writing_predictions : bool
        If writing predictions is turned on or off.
    """

    def __init__(
        self,
        write_strategy: WriteStrategy,
        dirpath: Union[Path, str] = "predictions",
    ):
        """
        A PyTorch Lightning callback to save predictions.

        Parameters
        ----------
        write_strategy : WriteStrategy
            A strategy for writing predictions.
        dirpath : pathlib.Path or str, default="predictions"
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory.
        """
        # predictions are written after every batch, see `write_on_batch_end`
        super().__init__(write_interval="batch")

        # Toggle for CAREamist to switch off saving if desired
        self.writing_predictions: bool = True

        self.write_strategy: WriteStrategy = write_strategy

        # forward declaration
        self.dirpath: Path
        # attribute initialisation
        self._init_dirpath(dirpath)

    @classmethod
    def from_write_func_params(
        cls,
        write_type: SupportedWriteType,
        tiled: bool,
        write_func: WriteFunc | None = None,
        write_extension: str | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
        dirpath: Union[Path, str] = "predictions",
    ) -> PredictionWriterCallback:  # TODO: change type hint to self (find out how)
        """
        Initialize a `PredictionWriterCallback` from write function parameters.

        This will automatically create a `WriteStrategy` to be passed to the
        initialization of `PredictionWriterCallback`.

        Parameters
        ----------
        write_type : {"tiff", "custom"}
            The data type to save as, includes custom.
        tiled : bool
            Whether the prediction will be tiled or not.
        write_func : WriteFunc, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` a function to save the data must be passed. See notes below.
        write_extension : str, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` an extension to save the data with must be passed.
        write_func_kwargs : dict of {str: Any}, optional
            Additional keyword arguments to be passed to the save function.
        dirpath : pathlib.Path or str, default="predictions"
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory.

        Returns
        -------
        PredictionWriterCallback
            Callback for writing predictions.
        """
        write_strategy = create_write_strategy(
            write_type=write_type,
            tiled=tiled,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )
        return cls(write_strategy=write_strategy, dirpath=dirpath)

    def _init_dirpath(self, dirpath: Union[Path, str]) -> None:
        """
        Initialize directory path. Should only be called from `__init__`.

        A relative path is resolved against the current working directory and a
        warning is logged with the resulting absolute path.

        Parameters
        ----------
        dirpath : pathlib.Path or str
            See `__init__` description.
        """
        dirpath = Path(dirpath)
        if not dirpath.is_absolute():
            dirpath = Path.cwd() / dirpath
            logger.warning(
                "Prediction output directory is not absolute, absolute path assumed to "
                f"be '{dirpath}'"
            )
        self.dirpath = dirpath

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """
        Create the prediction output directory when predict begins.

        Called when fit, validate, test, predict, or tune begins.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        stage : str
            Stage of training e.g. 'predict', 'fit', 'validate'.
        """
        super().setup(trainer, pl_module, stage)
        if stage == "predict":
            # make prediction output directory
            logger.info("Making prediction output directory.")
            self.dirpath.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,  # TODO: change to expected type
        batch_indices: Sequence[int] | None,
        batch: Any,  # TODO: change to expected type
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """
        Write predictions at the end of a batch.

        The method of prediction is determined by the attribute `write_strategy`.
        Nothing is written when `writing_predictions` is `False`.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        prediction : Any
            Prediction outputs of `batch`.
        batch_indices : sequence of int, optional
            Batch indices.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.

        Raises
        ------
        TypeError
            If the prediction dataset is neither `IterablePredDataset` nor
            `IterableTiledPredDataset`.
        """
        # if writing prediction is turned off
        if not self.writing_predictions:
            return

        # the trainer exposes a list when several predict dataloaders are used
        dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dataloader: DataLoader = (
            dataloaders[dataloader_idx]
            if isinstance(dataloaders, list)
            else dataloaders
        )
        dataset: ValidPredDatasets = dataloader.dataset
        if not isinstance(dataset, (IterablePredDataset, IterableTiledPredDataset)):
            # Note: Error will be raised before here from the source type
            # This is for extra redundancy of errors.
            raise TypeError(
                "Prediction dataset has to be `IterableTiledPredDataset` or "
                "`IterablePredDataset`. Cannot be `InMemoryPredDataset` because "
                "filenames are taken from the original file."
            )

        self.write_strategy.write_batch(
            trainer=trainer,
            pl_module=pl_module,
            prediction=prediction,
            batch_indices=batch_indices,
            batch=batch,
            batch_idx=batch_idx,
            dataloader_idx=dataloader_idx,
            dirpath=self.dirpath,
        )
@@ -0,0 +1,399 @@
1
+ """Module containing different strategies for writing predictions."""
2
+
3
+ from collections.abc import Sequence
4
+ from pathlib import Path
5
+ from typing import Any, Protocol, Union
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+ from pytorch_lightning import LightningModule, Trainer
10
+ from torch.utils.data import DataLoader
11
+
12
+ from careamics.config.data.tile_information import TileInformation
13
+ from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
14
+ from careamics.file_io import WriteFunc
15
+ from careamics.prediction_utils import stitch_prediction_single
16
+
17
+ from .file_path_utils import create_write_file_path, get_sample_file_path
18
+
19
+
20
class WriteStrategy(Protocol):
    """
    Structural interface shared by all prediction write strategies.

    For static type checking, any object providing a compatible `write_batch`
    method can be used where a `WriteStrategy` is expected.
    """

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,  # TODO: change to expected type
        batch_indices: Sequence[int] | None,
        batch: Any,  # TODO: change to expected type
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Write a single batch of predictions; every strategy must implement this.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer running the prediction loop.
        pl_module : LightningModule
            PyTorch Lightning module producing the predictions.
        prediction : Any
            Predictions on `batch`.
        batch_indices : sequence of int or None
            Indices identifying the samples in the batch, if available.
        batch : Any
            Input batch.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader the batch comes from.
        dirpath : Path
            Directory in which predictions are saved.
        """
        ...
56
+
57
+
58
+ class CacheTiles(WriteStrategy):
59
+ """
60
+ A write strategy that will cache tiles.
61
+
62
+ Tiles are cached until a whole image is predicted on. Then the stitched
63
+ prediction is saved.
64
+
65
+ Parameters
66
+ ----------
67
+ write_func : WriteFunc
68
+ Function used to save predictions.
69
+ write_extension : str
70
+ Extension added to prediction file paths.
71
+ write_func_kwargs : dict of {str: Any}
72
+ Extra kwargs to pass to `write_func`.
73
+
74
+ Attributes
75
+ ----------
76
+ write_func : WriteFunc
77
+ Function used to save predictions.
78
+ write_extension : str
79
+ Extension added to prediction file paths.
80
+ write_func_kwargs : dict of {str: Any}
81
+ Extra kwargs to pass to `write_func`.
82
+ tile_cache : list of numpy.ndarray
83
+ Tiles cached for stitching prediction.
84
+ tile_info_cache : list of TileInformation
85
+ Cached tile information for stitching prediction.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ write_func: WriteFunc,
91
+ write_extension: str,
92
+ write_func_kwargs: dict[str, Any],
93
+ ) -> None:
94
+ """
95
+ A write strategy that will cache tiles.
96
+
97
+ Tiles are cached until a whole image is predicted on. Then the stitched
98
+ prediction is saved.
99
+
100
+ Parameters
101
+ ----------
102
+ write_func : WriteFunc
103
+ Function used to save predictions.
104
+ write_extension : str
105
+ Extension added to prediction file paths.
106
+ write_func_kwargs : dict of {str: Any}
107
+ Extra kwargs to pass to `write_func`.
108
+ """
109
+ super().__init__()
110
+
111
+ self.write_func: WriteFunc = write_func
112
+ self.write_extension: str = write_extension
113
+ self.write_func_kwargs: dict[str, Any] = write_func_kwargs
114
+
115
+ # where tiles will be cached until a whole image has been predicted
116
+ self.tile_cache: list[NDArray] = []
117
+ self.tile_info_cache: list[TileInformation] = []
118
+
119
+ @property
120
+ def last_tiles(self) -> list[bool]:
121
+ """
122
+ List of bool to determine whether each tile in the cache is the last tile.
123
+
124
+ Returns
125
+ -------
126
+ list of bool
127
+ Whether each tile in the tile cache is the last tile.
128
+ """
129
+ return [tile_info.last_tile for tile_info in self.tile_info_cache]
130
+
131
+ def write_batch(
132
+ self,
133
+ trainer: Trainer,
134
+ pl_module: LightningModule,
135
+ prediction: tuple[NDArray, list[TileInformation]],
136
+ batch_indices: Sequence[int] | None,
137
+ batch: tuple[NDArray, list[TileInformation]],
138
+ batch_idx: int,
139
+ dataloader_idx: int,
140
+ dirpath: Path,
141
+ ) -> None:
142
+ """
143
+ Cache tiles until the last tile is predicted; save the stitched prediction.
144
+
145
+ Parameters
146
+ ----------
147
+ trainer : Trainer
148
+ PyTorch Lightning Trainer.
149
+ pl_module : LightningModule
150
+ PyTorch Lightning LightningModule.
151
+ prediction : Any
152
+ Predictions on `batch`.
153
+ batch_indices : sequence of int
154
+ Indices identifying the samples in the batch.
155
+ batch : Any
156
+ Input batch.
157
+ batch_idx : int
158
+ Batch index.
159
+ dataloader_idx : int
160
+ Dataloader index.
161
+ dirpath : Path
162
+ Path to directory to save predictions to.
163
+ """
164
+ dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
165
+ dataloader: DataLoader = (
166
+ dataloaders[dataloader_idx]
167
+ if isinstance(dataloaders, list)
168
+ else dataloaders
169
+ )
170
+ dataset: IterableTiledPredDataset = dataloader.dataset
171
+ if not isinstance(dataset, IterableTiledPredDataset):
172
+ raise TypeError("Prediction dataset is not `IterableTiledPredDataset`.")
173
+
174
+ # cache tiles (batches are split into single samples)
175
+ self.tile_cache.extend(np.split(prediction[0], prediction[0].shape[0]))
176
+ self.tile_info_cache.extend(prediction[1])
177
+
178
+ # save stitched prediction
179
+ if self._has_last_tile():
180
+
181
+ # get image tiles and remove them from the cache
182
+ tiles, tile_infos = self._get_image_tiles()
183
+ self._clear_cache()
184
+
185
+ # stitch prediction
186
+ prediction_image = stitch_prediction_single(
187
+ tiles=tiles, tile_infos=tile_infos
188
+ )
189
+
190
+ # write prediction
191
+ sample_id = tile_infos[0].sample_id # need this to select correct file name
192
+ input_file_path = get_sample_file_path(dataset=dataset, sample_id=sample_id)
193
+ file_path = create_write_file_path(
194
+ dirpath=dirpath,
195
+ file_path=input_file_path,
196
+ write_extension=self.write_extension,
197
+ )
198
+ self.write_func(
199
+ file_path=file_path, img=prediction_image[0], **self.write_func_kwargs
200
+ )
201
+
202
+ def _has_last_tile(self) -> bool:
203
+ """
204
+ Whether a last tile is contained in the cached tiles.
205
+
206
+ Returns
207
+ -------
208
+ bool
209
+ Whether a last tile is contained in the cached tiles.
210
+ """
211
+ return any(self.last_tiles)
212
+
213
+ def _clear_cache(self) -> None:
214
+ """Remove the tiles in the cache up to the first last tile."""
215
+ index = self._last_tile_index()
216
+ self.tile_cache = self.tile_cache[index + 1 :]
217
+ self.tile_info_cache = self.tile_info_cache[index + 1 :]
218
+
219
+ def _last_tile_index(self) -> int:
220
+ """
221
+ Find the index of the last tile in the tile cache.
222
+
223
+ Returns
224
+ -------
225
+ int
226
+ Index of last tile.
227
+
228
+ Raises
229
+ ------
230
+ ValueError
231
+ If there is no last tile in the tile cache.
232
+ """
233
+ last_tiles = self.last_tiles
234
+ if not any(last_tiles):
235
+ raise ValueError("No last tile in the tile cache.")
236
+ index = np.where(last_tiles)[0][0]
237
+ return index
238
+
239
def _get_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]:
    """
    Return the cached tiles that make up one complete image.

    The returned slices cover every cache entry up to and including the
    first tile flagged as a last tile. The caches themselves are not
    modified; slicing produces new sequences.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        Tiles and their matching tile information, ready to be stitched into
        a full image.
    """
    stop = self._last_tile_index() + 1
    return self.tile_cache[:stop], self.tile_info_cache[:stop]
252
+
253
+
254
class WriteTilesZarr(WriteStrategy):
    """
    Strategy to write tiles to Zarr file.

    Not implemented yet: `write_batch` always raises `NotImplementedError`.
    """

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,
        batch_indices: Sequence[int] | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Write tiles to zarr file.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : Any
            Predictions on `batch`.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.

        Raises
        ------
        NotImplementedError
            Always; writing tiles to Zarr is not supported yet.
        """
        # Give users an explicit reason instead of a bare, message-less error.
        raise NotImplementedError(
            "Writing tiles to a Zarr file is not implemented yet."
        )
295
+
296
+
297
class WriteImage(WriteStrategy):
    """
    A strategy for writing image predictions (i.e. un-tiled predictions).

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        A strategy for writing image predictions (i.e. un-tiled predictions).

        Parameters
        ----------
        write_func : WriteFunc
            Function used to save predictions.
        write_extension : str
            Extension added to prediction file paths.
        write_func_kwargs : dict of {str: Any}
            Extra kwargs to pass to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: NDArray,
        batch_indices: Sequence[int] | None,
        batch: NDArray,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Save each full image of a prediction batch to its own file.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : NDArray
            Predictions on `batch`; the first axis indexes the samples of the
            batch, and each sample is written to a separate file.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : NDArray
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.

        Raises
        ------
        TypeError
            If trainer prediction dataset is not `IterablePredDataset`.
        """
        dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls
        ds: IterablePredDataset = dl.dataset
        if not isinstance(ds, IterablePredDataset):
            raise TypeError("Prediction dataset is not `IterablePredDataset`.")

        for i in range(prediction.shape[0]):
            # Each sample must be taken from its own slice of the batch;
            # indexing with a constant 0 would write the first image into
            # every output file.
            prediction_image = prediction[i]
            # Global sample index used to look up the matching input file.
            # NOTE(review): assumes `dl.batch_size` is set (not None) and that
            # every preceding batch was full — confirm for custom samplers.
            sample_id = batch_idx * dl.batch_size + i
            input_file_path = get_sample_file_path(dataset=ds, sample_id=sample_id)
            file_path = create_write_file_path(
                dirpath=dirpath,
                file_path=input_file_path,
                write_extension=self.write_extension,
            )
            self.write_func(
                file_path=file_path, img=prediction_image, **self.write_func_kwargs
            )