careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/lightning/dataset_ng/lightning_modules/unet_module.py
@@ -0,0 +1,221 @@
+ """Generic UNet Lightning module."""
+
+ from typing import Any, Union
+
+ import pytorch_lightning as L
+ import torch
+ from torch import nn
+ from torchmetrics import MetricCollection
+ from torchmetrics.image import PeakSignalNoiseRatio
+
+ from careamics.config import algorithm_factory
+ from careamics.config.algorithms import (
+     CAREAlgorithm,
+     N2NAlgorithm,
+     N2VAlgorithm,
+     PN2VAlgorithm,
+ )
+ from careamics.dataset_ng.dataset import ImageRegionData
+ from careamics.models.unet import UNet
+ from careamics.transforms import Denormalize
+ from careamics.utils.logging import get_logger
+ from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+ logger = get_logger(__name__)
+
+
+ class UnetModule(L.LightningModule):
+     """CAREamics PyTorch Lightning module for UNet-based algorithms.
+
+     Parameters
+     ----------
+     algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, PN2VAlgorithm, or dict
+         Configuration for the algorithm, either as an instance of a specific algorithm
+         class or a dictionary that can be converted to an algorithm instance.
+     """
+
+     def __init__(
+         self,
+         algorithm_config: Union[
+             CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, PN2VAlgorithm, dict
+         ],
+     ) -> None:
+         """Instantiate the UNet Lightning module.
+
+         Parameters
+         ----------
+         algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, PN2VAlgorithm, or dict
+             Configuration for the algorithm, either as an instance of a specific
+             algorithm class or a dictionary that can be converted to an algorithm
+             instance.
+         """
+         super().__init__()
+
+         if isinstance(algorithm_config, dict):
+             algorithm_config = algorithm_factory(algorithm_config)
+
+         self.config = algorithm_config
+         self.model: nn.Module = UNet(**algorithm_config.model.model_dump())
+
+         self._best_checkpoint_loaded = False
+
+         # TODO: how to support metric evaluation better
+         self.metrics = MetricCollection(PeakSignalNoiseRatio())
+
+     def forward(self, x: Any) -> Any:
+         """Default forward method.
+
+         Parameters
+         ----------
+         x : Any
+             Input data.
+
+         Returns
+         -------
+         Any
+             Output from the model.
+         """
+         return self.model(x)
+
+     def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
+         """Log training statistics.
+
+         Parameters
+         ----------
+         loss : Any
+             The loss value for the current training step.
+         batch_size : Any
+             The size of the batch used in the current training step.
+         """
+         self.log(
+             "train_loss",
+             loss,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+             batch_size=batch_size,
+         )
+
+         optimizer = self.optimizers()
+         if isinstance(optimizer, list):
+             current_lr = optimizer[0].param_groups[0]["lr"]
+         else:
+             current_lr = optimizer.param_groups[0]["lr"]
+         self.log(
+             "learning_rate",
+             current_lr,
+             on_step=False,
+             on_epoch=True,
+             logger=True,
+             batch_size=batch_size,
+         )
+
+     def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
+         """Log validation statistics.
+
+         Parameters
+         ----------
+         loss : Any
+             The loss value for the current validation step.
+         batch_size : Any
+             The size of the batch used in the current validation step.
+         """
+         self.log(
+             "val_loss",
+             loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+             batch_size=batch_size,
+         )
+         self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)
+
+     def _load_best_checkpoint(self) -> None:
+         """Load the best checkpoint from the trainer's checkpoint callback."""
+         if (
+             not hasattr(self.trainer, "checkpoint_callback")
+             or self.trainer.checkpoint_callback is None
+         ):
+             logger.warning("No checkpoint callback found, cannot load best checkpoint.")
+             return
+
+         best_model_path = self.trainer.checkpoint_callback.best_model_path
+         if best_model_path and best_model_path != "":
+             logger.info(f"Loading best checkpoint from: {best_model_path}")
+             model_state = torch.load(best_model_path, weights_only=True)["state_dict"]
+             self.load_state_dict(model_state)
+         else:
+             logger.warning("No best checkpoint found.")
+
+     def predict_step(
+         self,
+         batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
+         batch_idx: Any,
+         load_best_checkpoint: bool = False,
+     ) -> Any:
+         """Default predict step.
+
+         Parameters
+         ----------
+         batch : ImageRegionData or (ImageRegionData, ImageRegionData)
+             A tuple containing the input data and optionally the target data.
+         batch_idx : Any
+             The index of the current batch in the prediction loop.
+         load_best_checkpoint : bool, default=False
+             Whether to load the best checkpoint before making predictions.
+
+         Returns
+         -------
+         Any
+             The output batch containing the predictions.
+         """
+         if self._best_checkpoint_loaded is False and load_best_checkpoint:
+             self._load_best_checkpoint()
+             self._best_checkpoint_loaded = True
+
+         x = batch[0]
+         # TODO: add TTA
+         prediction = self.model(x.data).cpu().numpy()
+
+         means = self._trainer.datamodule.stats.means
+         stds = self._trainer.datamodule.stats.stds
+         denormalize = Denormalize(
+             image_means=means,
+             image_stds=stds,
+         )
+         denormalized_output = denormalize(prediction)
+
+         output_batch = ImageRegionData(
+             data=denormalized_output,
+             source=x.source,
+             data_shape=x.data_shape,
+             dtype=x.dtype,
+             axes=x.axes,
+             region_spec=x.region_spec,
+             additional_metadata={},
+         )
+         return output_batch
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers.
+
+         Returns
+         -------
+         Any
+             A dictionary containing the optimizer and learning rate scheduler.
+         """
+         optimizer_func = get_optimizer(self.config.optimizer.name)
+         optimizer = optimizer_func(
+             self.model.parameters(), **self.config.optimizer.parameters
+         )
+
+         scheduler_func = get_scheduler(self.config.lr_scheduler.name)
+         scheduler = scheduler_func(optimizer, **self.config.lr_scheduler.parameters)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
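The dictionary returned by `configure_optimizers` above follows the standard PyTorch Lightning pattern of bundling the optimizer, the scheduler, and a monitored metric; the `monitor` key is what keeps monitored schedulers such as `ReduceLROnPlateau` from raising a `MisconfigurationException`. The sketch below illustrates the same contract in a minimal, self-contained module; it deliberately uses plain `torch.optim` objects instead of the CAREamics `get_optimizer`/`get_scheduler` helpers, and everything in it is illustrative rather than part of the package.

# Minimal sketch of the optimizer/scheduler contract used by UnetModule above.
# Illustrative only: plain torch.optim objects instead of CAREamics helpers.
import pytorch_lightning as L
import torch
from torch import nn


class TinyModule(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(8, 8)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.model(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.model(x), y)
        # "val_loss" must be logged for the monitored scheduler below to work
        self.log("val_loss", loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # Same structure as UnetModule.configure_optimizers: the "monitor" key
        # tells Lightning which logged metric drives ReduceLROnPlateau.
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }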
careamics/lightning/dataset_ng/prediction/__init__.py
@@ -0,0 +1,16 @@
+ """Prediction utilities for the NG Dataset."""
+
+ __all__ = [
+     "combine_samples",
+     "convert_prediction",
+     "decollate_image_region_data",
+     "stitch_prediction",
+     "stitch_single_prediction",
+ ]
+
+ from .convert_prediction import (
+     combine_samples,
+     convert_prediction,
+     decollate_image_region_data,
+ )
+ from .stitch_prediction import stitch_prediction, stitch_single_prediction
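These re-exports define the public surface of the NG prediction helpers; the defining modules appear in the next two hunks. A short, assumed usage sketch (the `predictions` list would come from `Trainer.predict` on a CAREamics prediction dataloader, which is not shown here):

# Sketch of importing the helpers re-exported by this __init__; the commented
# call mirrors the convert_prediction signature defined in the next hunk.
from careamics.lightning.dataset_ng.prediction import (
    convert_prediction,
    stitch_prediction,
)

# predictions: list of ImageRegionData batches returned by Trainer.predict
# arrays, sources = convert_prediction(predictions, tiled=True)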
careamics/lightning/dataset_ng/prediction/convert_prediction.py
@@ -0,0 +1,198 @@
+ """Module containing functions to convert prediction outputs to the desired form."""
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from careamics.dataset_ng.dataset import ImageRegionData
+
+ from .stitch_prediction import group_tiles_by_key, stitch_prediction
+
+ if TYPE_CHECKING:
+     from torch import Tensor
+
+
+ def _decollate_batch_dict(
+     batched_dict: "dict[str, list | Tensor]",
+     index: int,
+ ) -> dict[str, int | tuple[int, ...]]:
+     """
+     Decollate element `index` from a batched dictionary.
+
+     This method is only compatible with integer elements.
+
+     Parameters
+     ----------
+     batched_dict : dict of {str: list or Tensor}
+         Batch dictionary where each value is a list of elements of length B or a
+         Tensor of shape (B,).
+     index : int
+         Index of the element to extract.
+
+     Returns
+     -------
+     dict of {str: int | tuple[int, ...]}
+         Dictionary of the `index` element in the collated batch.
+     """
+     item_dict = {
+         key: (
+             # cast to int otherwise we have Tensor scalars
+             # TODO for additional types (e.g. axes in additional_metadata), we will need
+             # to handle it differently
+             tuple(int(value[idx][index]) for idx in range(len(value)))
+             if isinstance(value, list)
+             else int(value[index])
+         )  # handles tensor (1D) vs list of 1D tensors (2D)
+         for key, value in batched_dict.items()
+     }
+
+     return item_dict
+
+
+ def decollate_image_region_data(
+     batch: ImageRegionData,
+ ) -> list[ImageRegionData]:
+     """
+     Decollate a batch of `ImageRegionData` into a list of `ImageRegionData`.
+
+     Input batch has the following structure:
+     - data: (B, C, (Z), Y, X) numpy.ndarray
+     - source: sequence of str, length B
+     - data_shape: sequence of tuple of int, each tuple being of length B
+     - dtype: list of numpy.dtype, length B
+     - axes: list of str, length B
+     - region_spec: dict of {str: sequence}, each sequence being of length B
+     - additional_metadata: dict of {str: Any}, each value being a sequence of length B
+
+     Parameters
+     ----------
+     batch : ImageRegionData
+         Batch of `ImageRegionData`.
+
+     Returns
+     -------
+     list of ImageRegionData
+         List of `ImageRegionData`.
+     """
+     batch_size = batch.data.shape[0]
+     decollated: list[ImageRegionData] = []
+     for i in range(batch_size):
+         # unpack region spec irrespective of whether it is a PatchSpecs or TileSpecs
+         region_spec = _decollate_batch_dict(batch.region_spec, i)
+
+         # handle additional metadata
+         # currently only zarr chunks and shards may be stored there, as tuples.
+         # TODO if additional metadata becomes used for anything else, this function
+         # call may not be appropriate anymore.
+         additional_metadata = _decollate_batch_dict(batch.additional_metadata, i)
+
+         # data shape
+         assert isinstance(batch.data_shape, list)
+         data_shape = tuple(int(dim[i]) for dim in batch.data_shape)
+
+         image_region = ImageRegionData(
+             data=batch.data[i],  # discard batch dimension
+             source=batch.source[i],
+             dtype=batch.dtype[i],
+             data_shape=data_shape,
+             axes=batch.axes[i],
+             region_spec=region_spec,  # type: ignore
+             additional_metadata=additional_metadata,
+         )
+         decollated.append(image_region)
+
+     return decollated
+
+
+ def combine_samples(
+     predictions: list[ImageRegionData],
+ ) -> tuple[list[NDArray], list[str]]:
+     """
+     Combine predictions by `data_idx`.
+
+     Images are first grouped by the `data_idx` found in their `region_spec`, then
+     sorted by ascending `sample_idx` before being stacked along the `S` dimension.
+
+     Parameters
+     ----------
+     predictions : list of ImageRegionData
+         List of `ImageRegionData`.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         List of combined predictions, one per unique `data_idx`.
+     list of str
+         List of sources, one per unique `data_idx`.
+     """
+     # group predictions by data idx
+     grouped_prediction: dict[int, list[ImageRegionData]] = group_tiles_by_key(
+         predictions, key="data_idx"
+     )
+
+     # sort predictions by sample idx
+     combined_predictions: list[NDArray] = []
+     combined_sources: list[str] = []
+     for data_idx in sorted(grouped_prediction.keys()):
+         image_regions = grouped_prediction[data_idx]
+         combined_sources.append(image_regions[0].source)
+
+         # sort by sample idx
+         image_regions.sort(key=lambda x: x.region_spec["sample_idx"])
+
+         # remove singleton dims and stack along S axis
+         combined_data = np.stack([img.data.squeeze() for img in image_regions], axis=0)
+         combined_predictions.append(combined_data)
+
+     return combined_predictions, combined_sources
+
+
+ def convert_prediction(
+     predictions: list[ImageRegionData],
+     tiled: bool,
+ ) -> tuple[list[NDArray], list[str]]:
+     """
+     Convert the Lightning trainer outputs to the desired form.
+
+     This method allows decollating batches and stitching back together tiled
+     predictions.
+
+     If the `source` of all predictions is "array" (see `InMemoryImageStack`), then the
+     returned sources list will be empty.
+
+     Parameters
+     ----------
+     predictions : list of ImageRegionData
+         Output from `Trainer.predict`, list of batches.
+     tiled : bool
+         Whether the predictions are tiled.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         List of arrays with the axes SC(Z)YX.
+     list of str
+         List of sources, one per output or empty if all equal to `array`.
+     """
+     # decollate batches
+     decollated_predictions: list[ImageRegionData] = []
+     for batch in predictions:
+         decollated_batch = decollate_image_region_data(batch)
+         decollated_predictions.extend(decollated_batch)
+
+     if not tiled and "total_tiles" in decollated_predictions[0].region_spec:
+         raise ValueError(
+             "Predictions contain `total_tiles` in region_spec but `tiled` is set to "
+             "False."
+         )
+
+     if tiled:
+         predictions_output, sources = stitch_prediction(decollated_predictions)
+     else:
+         predictions_output, sources = combine_samples(decollated_predictions)
+
+     if set(sources) == {"array"}:
+         sources = []
+
+     return predictions_output, sources
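To make the `_decollate_batch_dict` behaviour above concrete: PyTorch's default collation turns per-item scalar fields into a 1D tensor of shape (B,) and per-item tuple fields into a list of 1D tensors, one per tuple position. The sketch below undoes that for a made-up region_spec batch; the dictionary contents and the standalone `decollate` helper are hypothetical, only the logic mirrors the private function in this hunk.

# Illustrative sketch of what _decollate_batch_dict does to a collated
# region_spec. The field names and values here are hypothetical.
import torch

# A batch of two tile specs after default collation:
# - scalar fields become a 1D tensor of shape (B,)
# - tuple fields become a list of 1D tensors, one per tuple position
batched_region_spec = {
    "data_idx": torch.tensor([0, 0]),
    "sample_idx": torch.tensor([0, 1]),
    "coords": [torch.tensor([0, 64]), torch.tensor([0, 64])],  # (y, x) per tile
}

def decollate(batched, index):
    # same structure as _decollate_batch_dict: lists of 1D tensors are
    # re-zipped into per-item tuples, 1D tensors are indexed directly
    return {
        key: (
            tuple(int(value[pos][index]) for pos in range(len(value)))
            if isinstance(value, list)
            else int(value[index])
        )
        for key, value in batched.items()
    }

print(decollate(batched_region_spec, 0))
# {'data_idx': 0, 'sample_idx': 0, 'coords': (0, 0)}
print(decollate(batched_region_spec, 1))
# {'data_idx': 0, 'sample_idx': 1, 'coords': (64, 64)}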
careamics/lightning/dataset_ng/prediction/stitch_prediction.py
@@ -0,0 +1,171 @@
+ """Tiled prediction stitching utilities."""
+
+ import builtins
+ from collections import defaultdict
+ from typing import Literal
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from careamics.dataset_ng.dataset import ImageRegionData
+ from careamics.dataset_ng.patching_strategies import TileSpecs
+
+
+ def group_tiles_by_key(
+     tiles: list[ImageRegionData], key: Literal["data_idx", "sample_idx"]
+ ) -> dict[int, list[ImageRegionData]]:
+     """
+     Group tiles by key.
+
+     Parameters
+     ----------
+     tiles : list of ImageRegionData
+         List of tiles to group.
+     key : {'data_idx', 'sample_idx'}
+         Key to group tiles by.
+
+     Returns
+     -------
+     dict of {int: list of ImageRegionData}
+         Dictionary mapping key values to lists of tiles.
+     """
+     sorted_tiles: dict[int, list[ImageRegionData]] = defaultdict(list)
+     for tile in tiles:
+         key_value = tile.region_spec[key]
+         sorted_tiles[key_value].append(tile)
+     return sorted_tiles
+
+
+ def stitch_prediction(
+     tiles: list[ImageRegionData],
+ ) -> tuple[list[NDArray], list[str]]:
+     """
+     Stitch tiles back together to form full images.
+
+     Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be a
+     singleton dimension.
+
+     Parameters
+     ----------
+     tiles : list of ImageRegionData
+         Cropped tiles and their respective stitching coordinates. Can contain tiles
+         from multiple images.
+
+     Returns
+     -------
+     list of numpy.ndarray
+         Full images, may be a single image.
+     list of str
+         List of sources, one per output.
+     """
+     # group tiles by data index
+     grouped_tiles: dict[int, list[ImageRegionData]] = group_tiles_by_key(
+         tiles, key="data_idx"
+     )
+
+     # stitch each image separately
+     image_predictions: list[NDArray] = []
+     image_sources: list[str] = []
+     for data_idx in sorted(grouped_tiles.keys()):
+         image_predictions.append(stitch_single_prediction(grouped_tiles[data_idx]))
+         image_sources.append(grouped_tiles[data_idx][0].source)
+
+     return image_predictions, image_sources
+
+
+ def stitch_single_prediction(
+     tiles: list[ImageRegionData],
+ ) -> NDArray:
+     """
+     Stitch tiles back together to form a full image.
+
+     Tiles are of dimensions C(Z)YX, where C is the number of channels and can be a
+     singleton dimension.
+
+     Parameters
+     ----------
+     tiles : list of ImageRegionData
+         Cropped tiles and their respective stitching coordinates.
+
+     Returns
+     -------
+     numpy.ndarray
+         Full image, with dimensions SC(Z)YX.
+     """
+     data_shape = tiles[0].data_shape
+     predicted_image = np.zeros(data_shape, dtype=np.float32)
+
+     if "S" in tiles[0].axes:
+         tiles_by_sample = group_tiles_by_key(tiles, key="sample_idx")
+         for sample_idx in tiles_by_sample.keys():
+             sample_tiles = tiles_by_sample[sample_idx]
+             stitched_sample = stitch_single_sample(sample_tiles)
+
+             # compute sample slice
+             sample_slice = slice(
+                 sample_idx,
+                 sample_idx + 1,
+             )
+
+             # insert stitched sample into predicted image
+             predicted_image[sample_slice] = stitched_sample.astype(np.float32)
+     else:
+         # stitch as a single sample
+         # predicted_image has singleton sample dimension
+         predicted_image[0] = stitch_single_sample(tiles)
+
+     return predicted_image
+
+
+ def stitch_single_sample(
+     tiles: list[ImageRegionData],
+ ) -> NDArray:
+     """
+     Stitch tiles back together to form a full sample.
+
+     Tiles are of dimensions C(Z)YX, where C is the number of channels and can be a
+     singleton dimension.
+
+     Parameters
+     ----------
+     tiles : list of ImageRegionData
+         Cropped tiles and their respective stitching coordinates.
+
+     Returns
+     -------
+     numpy.ndarray
+         Full sample, with dimensions C(Z)YX.
+     """
+     data_shape = tiles[0].data_shape  # SC(Z)YX
+     predicted_sample = np.zeros(data_shape[1:], dtype=np.float32)
+
+     for tile in tiles:
+         # compute crop coordinates and stitching coordinates
+         tile_spec: TileSpecs = tile.region_spec  # type: ignore
+         crop_coords = tile_spec["crop_coords"]
+         crop_size = tile_spec["crop_size"]
+         stitch_coords = tile_spec["stitch_coords"]
+
+         crop_slices: tuple[builtins.ellipsis | slice, ...] = (
+             ...,
+             *[
+                 slice(start, start + length)
+                 for start, length in zip(crop_coords, crop_size, strict=True)
+             ],
+         )
+
+         stitch_slices: tuple[builtins.ellipsis | slice, ...] = (
+             ...,
+             *[
+                 slice(start, start + length)
+                 for start, length in zip(stitch_coords, crop_size, strict=True)
+             ],
+         )
+
+         # crop predicted tile according to overlap coordinates
+         cropped_tile = tile.data[crop_slices]
+
+         # insert cropped tile into predicted image
+         predicted_sample[stitch_slices] = cropped_tile.astype(np.float32)
+
+     return predicted_sample
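The crop/stitch bookkeeping in `stitch_single_sample` reduces to simple slice arithmetic: each tile is cropped with `crop_coords`/`crop_size` and the cropped piece is written into the output at `stitch_coords`, so overlaps are resolved by keeping only the non-overlapping core of each tile. The snippet below is a minimal 1D numpy illustration with made-up coordinates; it does not use `ImageRegionData` or `TileSpecs`, and it drops the leading channel ellipsis of the real code.

# Minimal numpy illustration of the stitching arithmetic above. Two overlapping
# 1D "tiles" are cropped and written into the output array; all values are
# made up for the example.
import numpy as np

full = np.zeros(10, dtype=np.float32)

# tile 1 covers columns 0..5, tile 2 covers columns 4..9 (overlap of 2)
tiles = [
    {"data": np.full(6, 1.0), "crop_coords": (0,), "crop_size": (5,), "stitch_coords": (0,)},
    {"data": np.full(6, 2.0), "crop_coords": (1,), "crop_size": (5,), "stitch_coords": (5,)},
]

for tile in tiles:
    crop = tuple(
        slice(start, start + length)
        for start, length in zip(tile["crop_coords"], tile["crop_size"], strict=True)
    )
    stitch = tuple(
        slice(start, start + length)
        for start, length in zip(tile["stitch_coords"], tile["crop_size"], strict=True)
    )
    # keep the non-overlapping core of the tile and place it in the output
    full[stitch] = tile["data"][crop]

print(full)  # [1. 1. 1. 1. 1. 2. 2. 2. 2. 2.]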