careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,219 @@
1
+ """Module containing `PredictionWriterCallback` class."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Sequence
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from pytorch_lightning import LightningModule, Trainer
10
+ from pytorch_lightning.callbacks import BasePredictionWriter
11
+
12
+ from careamics.dataset_ng.dataset import ImageRegionData
13
+ from careamics.file_io.write.get_func import SupportedWriteType, WriteFunc
14
+ from careamics.lightning.dataset_ng.prediction import decollate_image_region_data
15
+ from careamics.utils import get_logger
16
+
17
+ from .write_strategy import WriteStrategy
18
+ from .write_strategy_factory import create_write_strategy
19
+
20
logger = get_logger(__name__)


class PredictionWriterCallback(BasePredictionWriter):
    """
    PyTorch Lightning callback to save predictions.

    A `WriteStrategy` must be provided at instantiation or later via
    `set_writing_strategy`.

    Parameters
    ----------
    dirpath : Path or str, default=""
        The path to the directory where prediction outputs will be saved. If
        `dirpath` is not absolute it is assumed to be relative to current working
        directory. If left empty, no output directory is configured (e.g. when the
        destination is provided by a zarr store path).
    write_strategy : WriteStrategy or None, default=None
        A strategy for writing predictions.

    Attributes
    ----------
    writing_predictions : bool
        If writing predictions is turned on or off.
    dirpath : pathlib.Path or None
        Absolute path to the directory where prediction outputs will be saved, or
        None if no directory was provided.
    write_strategy : WriteStrategy or None
        A strategy for writing predictions, or None if it has not been set yet.
    """

    def __init__(
        self,
        dirpath: Path | str = "",
        write_strategy: WriteStrategy | None = None,
    ):
        """
        Constructor.

        A `WriteStrategy` must be provided at instantiation or later via
        `set_writing_strategy`.

        Parameters
        ----------
        dirpath : pathlib.Path or str, default=""
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory. If empty, no output directory is configured.
        write_strategy : WriteStrategy or None, default=None
            A strategy for writing predictions.
        """
        super().__init__(write_interval="batch")

        # flag to turn writing on or off, see `disable_writing`
        self.writing_predictions: bool = True

        # Both attributes are always assigned so that the `is None` checks in
        # `setup` and `write_on_batch_end` work instead of raising AttributeError.
        self.write_strategy: WriteStrategy | None = write_strategy
        self.dirpath: Path | None = None

        # if a dirpath is provided, initialize it
        # in some cases (e.g. zarr), destination is provided by the zarr store path
        if dirpath != "":
            self._init_dirpath(dirpath)

    def disable_writing(self, disable_writing: bool) -> None:
        """
        Disable (or re-enable) writing predictions.

        Parameters
        ----------
        disable_writing : bool
            If writing predictions should be disabled. Passing `False` turns
            writing back on.
        """
        self.writing_predictions = not disable_writing

    def _init_dirpath(self, dirpath: Path | str) -> None:
        """
        Initialize directory path. Should only be called from `__init__`.

        Relative paths are resolved against the current working directory.

        Parameters
        ----------
        dirpath : pathlib.Path or str
            See `__init__` description.
        """
        dirpath = Path(dirpath)
        if not dirpath.is_absolute():
            dirpath = Path.cwd() / dirpath
            logger.warning(
                "Prediction output directory is not absolute, absolute path assumed to "
                f"be '{dirpath}'"
            )
        self.dirpath = dirpath

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """
        Create the prediction output directory when predict begins.

        Called when fit, validate, test, predict, or tune begins. The directory is
        only created for the 'predict' stage and only if a `dirpath` was given.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        stage : str
            Stage of training e.g. 'predict', 'fit', 'validate'.
        """
        super().setup(trainer, pl_module, stage)
        if stage == "predict" and self.dirpath is not None:
            # make prediction output directory
            logger.info("Making prediction output directory.")
            self.dirpath.mkdir(parents=True, exist_ok=True)

    def set_writing_strategy(
        self,
        write_type: SupportedWriteType,
        tiled: bool,
        write_func: WriteFunc | None = None,
        write_extension: str | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Set the writing strategy.

        Must be called before writing predictions, unless a strategy was passed
        at instantiation.

        Parameters
        ----------
        write_type : SupportedWriteType
            The type of writing to perform.
        tiled : bool
            Whether to write in tiled format.
        write_func : WriteFunc or None, default=None
            A custom writing function.
        write_extension : str or None, default=None
            The file extension to use when writing files.
        write_func_kwargs : dict of str to Any, default=None
            Additional keyword arguments to pass to `write_func`.
        """
        self.write_strategy = create_write_strategy(
            write_type=write_type,
            tiled=tiled,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: ImageRegionData,
        batch_indices: Sequence[int] | None,
        batch: ImageRegionData,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """
        Write predictions at the end of a batch.

        Writing method is determined by the attribute `write_strategy`.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        prediction : ImageRegionData
            Prediction outputs of `batch`.
        batch_indices : sequence of int, optional
            Batch indices.
        batch : ImageRegionData
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.

        Raises
        ------
        RuntimeError
            If writing is enabled but no write strategy has been set.
        """
        # if writing prediction is turned off
        if not self.writing_predictions:
            return

        if self.write_strategy is None:
            raise RuntimeError(
                "No write strategy defined for `PredictionWriterCallback`, cannot write"
                " predictions. Call `set_writing_strategy` to pass a write strategy."
            )

        assert prediction is not None
        predictions = decollate_image_region_data(prediction)

        # NOTE(review): `self.dirpath` may be None when no directory was given
        # (e.g. zarr, where the store path is the destination); the strategy is
        # expected to handle that case.
        self.write_strategy.write_batch(
            dirpath=self.dirpath,
            predictions=predictions,
        )
@@ -0,0 +1,91 @@
1
+ """A strategy writing whole images directly."""
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ from careamics.dataset_ng.dataset import ImageRegionData
7
+ from careamics.file_io import WriteFunc
8
+ from careamics.lightning.dataset_ng.prediction import (
9
+ combine_samples,
10
+ )
11
+
12
+ from .file_path_utils import create_write_file_path
13
+ from .write_strategy import WriteStrategy
14
+
15
+
16
# TODO bug: batch is over samples for whole images, if one batch does not cover
# all samples, it will write an incomplete image, then overwrite it with the next
# batch
class WriteImage(WriteStrategy):
    """
    Write strategy for whole-image (un-tiled) predictions.

    Parameters
    ----------
    write_func : WriteFunc
        Callable that saves a single image to disk.
    write_extension : str
        File extension appended to each output path.
    write_func_kwargs : dict of {str: Any}
        Extra keyword arguments forwarded to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Callable that saves a single image to disk.
    write_extension : str
        File extension appended to each output path.
    write_func_kwargs : dict of {str: Any}
        Extra keyword arguments forwarded to `write_func`.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        Store the writing function, extension and extra keyword arguments.

        Parameters
        ----------
        write_func : WriteFunc
            Callable that saves a single image to disk.
        write_extension : str
            File extension appended to each output path.
        write_func_kwargs : dict of {str: Any}
            Extra keyword arguments forwarded to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Combine decollated samples into images and save each one.

        Parameters
        ----------
        dirpath : Path
            Directory in which the images are written.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        assert predictions is not None

        combined_images, image_sources = combine_samples(predictions)

        # each combined image is saved under a path derived from its source
        for position, combined in enumerate(combined_images):
            source_path = Path(image_sources[position])
            output_path = create_write_file_path(
                dirpath=dirpath,
                file_path=source_path,
                write_extension=self.write_extension,
            )
            self.write_func(
                file_path=output_path, img=combined, **self.write_func_kwargs
            )
@@ -0,0 +1,27 @@
1
+ """Module containing different strategies for writing predictions."""
2
+
3
+ from pathlib import Path
4
+ from typing import Protocol
5
+
6
+ from careamics.dataset_ng.dataset import ImageRegionData
7
+
8
+
9
class WriteStrategy(Protocol):
    """
    Structural interface shared by all prediction write strategies.

    Any class providing a `write_batch` method with a compatible signature
    satisfies this protocol.
    """

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Write one batch of predictions.

        Parameters
        ----------
        dirpath : Path
            Directory in which predictions should be saved.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        ...
@@ -0,0 +1,214 @@
1
+ """Module containing convenience function to create `WriteStrategy`."""
2
+
3
+ from typing import Any
4
+
5
+ from careamics.config.support import SupportedData
6
+ from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
7
+
8
+ from .cached_tiles_strategy import CachedTiles
9
+ from .write_image_strategy import WriteImage
10
+ from .write_strategy import WriteStrategy
11
+ from .write_tiles_zarr_strategy import WriteTilesZarr
12
+
13
+
14
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: WriteFunc | None = None,
    write_extension: str | None = None,
    write_func_kwargs: dict[str, Any] | None = None,
) -> WriteStrategy:
    """
    Build a write strategy from convenient parameters.

    Parameters
    ----------
    write_type : {"tiff", "zarr", "custom"}
        The data type to save as, includes custom.
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        Ignored for known `write_type` values. For a custom `write_type` a
        function to save the data must be passed. See notes below.
    write_extension : str, optional
        Ignored for known `write_type` values. For a custom `write_type` an
        extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing predictions.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
    ```
    write_func(file_path=file_path, img=img, **kwargs)
    ```
    """
    kwargs: dict[str, Any] = {} if write_func_kwargs is None else write_func_kwargs

    if tiled:
        # tiled predictions: `CachedTiles` or `WriteTilesZarr`
        return _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=kwargs,
        )

    # whole images are written directly
    resolved_func = select_write_func(write_type=write_type, write_func=write_func)
    resolved_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return WriteImage(
        write_func=resolved_func,
        write_extension=resolved_extension,
        write_func_kwargs=kwargs,
    )
80
+
81
+
82
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: WriteFunc | None,
    write_extension: str | None,
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Create a tiled write strategy.

    Either `CachedTiles`, for caching tiles until a whole image is predicted, or
    `WriteTilesZarr`, for writing tiles directly to disk.

    Parameters
    ----------
    write_type : {"tiff", "zarr", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for known `write_type` values (and always ignored for "zarr"). For
        a custom `write_type` a function to save the data must be passed.
    write_extension : str, optional
        Ignored for known `write_type` values (and always ignored for "zarr"). For
        a custom `write_type` an extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function. Ignored
        for "zarr".

    Returns
    -------
    WriteStrategy
        `WriteTilesZarr` if `write_type="zarr"`, otherwise `CachedTiles`.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` or `write_extension` has not
        been given.
    """
    if write_type == "zarr":
        # tiles are written straight into the zarr store; no write function needed
        return WriteTilesZarr()

    write_func = select_write_func(write_type=write_type, write_func=write_func)
    write_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return CachedTiles(
        write_func=write_func,
        write_extension=write_extension,
        write_func_kwargs=write_func_kwargs,
    )
129
+
130
+
131
def select_write_func(
    write_type: SupportedWriteType, write_func: WriteFunc | None = None
) -> WriteFunc:
    """
    Pick the function used to write images.

    Known write types map to their built-in writer; the "custom" type requires
    the caller to supply `write_func`.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored for known `write_type` values. For a custom `write_type` a
        function to save the data must be passed. See notes below.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```
    """
    if write_type != SupportedData.CUSTOM:
        return get_write_func(write_type)

    if write_func is None:
        raise ValueError(
            "A save function must be provided for custom data types."
            # TODO: link to how save functions should be implemented
        )
    return write_func
176
+
177
+
178
def select_write_extension(
    write_type: SupportedWriteType, write_extension: str | None = None
) -> str:
    """
    Return an extension to add to file paths.

    If `write_type` is "custom" then `write_extension` is returned, otherwise the
    known extension of `write_type` is selected.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        Ignored for known `write_type` values. For a custom `write_type` an
        extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given.
    """
    write_type_: SupportedData = SupportedData(write_type)  # new variable for mypy
    if write_type_ == SupportedData.CUSTOM:
        if write_extension is None:
            raise ValueError("A save extension must be provided for custom data types.")
        return write_extension

    # kind of a weird pattern -> reason to move get_extension from SupportedData
    return write_type_.get_extension(write_type_)