careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/careamist.py ADDED
@@ -0,0 +1,961 @@
+ """A class to train, predict and export models in CAREamics."""
+
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import Any, Literal, Union, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import (
+     Callback,
+     EarlyStopping,
+     ModelCheckpoint,
+ )
+ from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
+
+ from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
+ from careamics.config.support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedData,
+     SupportedLogger,
+ )
+ from careamics.dataset.dataset_utils import list_files, reshape_array
+ from careamics.file_io import WriteFunc, get_write_func
+ from careamics.lightning import (
+     FCNModule,
+     HyperParametersCallback,
+     PredictDataModule,
+     ProgressBarCallback,
+     TrainDataModule,
+     create_predict_datamodule,
+ )
+ from careamics.model_io import export_to_bmz, load_pretrained
+ from careamics.prediction_utils import convert_outputs
+ from careamics.utils import check_path_exists, get_logger
+ from careamics.utils.lightning_utils import read_csv_logger
+
+ logger = get_logger(__name__)
+
+ LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
+
+
+ # TODO: type ignores have been added because of the czi data type in the data configuration
+ class CAREamist:
+     """Main CAREamics class, allowing training and prediction using various algorithms.
+
+     Parameters
+     ----------
+     source : pathlib.Path or str or CAREamics Configuration
+         Path to a configuration file or a trained model, or a configuration object.
+     work_dir : str or pathlib.Path, optional
+         Path to the working directory in which to save checkpoints and logs,
+         by default None.
+     callbacks : list of Callback, optional
+         List of callbacks to use during training and prediction, by default None.
+     enable_progress_bar : bool
+         Whether a progress bar will be displayed during training, validation and
+         prediction.
+
+     Attributes
+     ----------
+     model : CAREamicsModule
+         CAREamics model.
+     cfg : Configuration
+         CAREamics configuration.
+     trainer : Trainer
+         PyTorch Lightning trainer.
+     experiment_logger : TensorBoardLogger or WandbLogger
+         Experiment logger, "wandb" or "tensorboard".
+     work_dir : pathlib.Path
+         Working directory.
+     train_datamodule : TrainDataModule
+         Training datamodule.
+     pred_datamodule : PredictDataModule
+         Prediction datamodule.
+     """
+
+     @overload
+     def __init__(  # numpydoc ignore=GL08
+         self,
+         source: Union[Path, str],
+         work_dir: Union[Path, str] | None = None,
+         callbacks: list[Callback] | None = None,
+         enable_progress_bar: bool = True,
+     ) -> None: ...
+
+     @overload
+     def __init__(  # numpydoc ignore=GL08
+         self,
+         source: Configuration,
+         work_dir: Union[Path, str] | None = None,
+         callbacks: list[Callback] | None = None,
+         enable_progress_bar: bool = True,
+     ) -> None: ...
+
+     def __init__(
+         self,
+         source: Union[Path, str, Configuration],
+         work_dir: Union[Path, str] | None = None,
+         callbacks: list[Callback] | None = None,
+         enable_progress_bar: bool = True,
+     ) -> None:
+         """
+         Initialize CAREamist with a configuration object or a path.
+
+         A configuration object can be created directly by calling `Configuration`,
+         by using the configuration factory, or by loading a configuration from a
+         yaml file.
+
+         The path can point either to a yaml file with parameters or to a saved
+         checkpoint.
+
+         If no working directory is provided, the current working directory is used.
+
+         Parameters
+         ----------
+         source : pathlib.Path or str or CAREamics Configuration
+             Path to a configuration file or a trained model, or a configuration
+             object.
+         work_dir : str or pathlib.Path, optional
+             Path to the working directory in which to save checkpoints and logs,
+             by default None.
+         callbacks : list of Callback, optional
+             List of callbacks to use during training and prediction, by default None.
+         enable_progress_bar : bool
+             Whether a progress bar will be displayed during training, validation and
+             prediction.
+
+         Raises
+         ------
+         NotImplementedError
+             If the model is loaded from the BioImage Model Zoo.
+         ValueError
+             If no hyperparameters are found in the checkpoint.
+         ValueError
+             If no datamodule hyperparameters are found in the checkpoint.
+         """
+         # select current working directory if work_dir is None
+         if work_dir is None:
+             self.work_dir = Path.cwd()
+             logger.warning(
+                 f"No working directory provided. Using current working directory: "
+                 f"{self.work_dir}."
+             )
+         else:
+             self.work_dir = Path(work_dir)
+
+         # configuration object
+         if isinstance(source, Configuration):
+             self.cfg = source
+
+             # instantiate model
+             if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
+                 self.model = FCNModule(
+                     algorithm_config=self.cfg.algorithm_config,
+                 )
+             else:
+                 raise NotImplementedError("Architecture not supported.")
+
+         # path to configuration file or model
+         else:
+             # TODO: update this check so models can be downloaded directly from BMZ
+             source = check_path_exists(source)
+
+             # configuration file
+             if source.is_file() and (
+                 source.suffix == ".yaml" or source.suffix == ".yml"
+             ):
+                 # load configuration
+                 self.cfg = load_configuration(source)
+
+                 # instantiate model
+                 if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
+                     self.model = FCNModule(
+                         algorithm_config=self.cfg.algorithm_config,
+                     )  # type: ignore
+                 else:
+                     raise NotImplementedError("Architecture not supported.")
+
+             # attempt loading a pre-trained model
+             else:
+                 self.model, self.cfg = load_pretrained(source)
+
+         # define the checkpoint saving callback
+         self._define_callbacks(callbacks, enable_progress_bar)
+
+         # instantiate logger
+         csv_logger = CSVLogger(
+             name=self.cfg.experiment_name,
+             save_dir=self.work_dir / "csv_logs",
+         )
+
+         if self.cfg.training_config.has_logger():
+             if self.cfg.training_config.logger == SupportedLogger.WANDB:
+                 experiment_logger: LOGGER_TYPES = [
+                     WandbLogger(
+                         name=self.cfg.experiment_name,
+                         save_dir=self.work_dir / Path("wandb_logs"),
+                     ),
+                     csv_logger,
+                 ]
+             elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
+                 experiment_logger = [
+                     TensorBoardLogger(
+                         save_dir=self.work_dir / Path("tb_logs"),
+                     ),
+                     csv_logger,
+                 ]
+         else:
+             experiment_logger = [csv_logger]
+
+         # instantiate trainer
+         self.trainer = Trainer(
+             enable_progress_bar=enable_progress_bar,
+             callbacks=self.callbacks,
+             default_root_dir=self.work_dir,
+             logger=experiment_logger,
+             **self.cfg.training_config.lightning_trainer_config or {},
+         )
+
+         # placeholder for the datamodules
+         self.train_datamodule: TrainDataModule | None = None
+         self.pred_datamodule: PredictDataModule | None = None
+
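# Illustrative usage sketch for the constructor above; not part of the packaged
# file. It assumes the `create_n2v_configuration` factory exposed by
# `careamics.config` (factory names and signatures may differ between versions).
from careamics import CAREamist
from careamics.config import create_n2v_configuration

# build a configuration with the factory, then hand it to CAREamist
config = create_n2v_configuration(
    experiment_name="n2v_demo",
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=10,
)
careamist = CAREamist(source=config, work_dir="runs/n2v_demo")

# alternatively, restore from a checkpoint or a YAML configuration file:
# careamist = CAREamist(source="runs/n2v_demo/checkpoints/last.ckpt")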
+     def _define_callbacks(
+         self, callbacks: list[Callback] | None, enable_progress_bar: bool
+     ) -> None:
+         """Define the callbacks for the training loop.
+
+         Parameters
+         ----------
+         callbacks : list of Callback, optional
+             List of callbacks to use during training and prediction, by default None.
+         enable_progress_bar : bool
+             Whether a progress bar will be displayed during training, validation and
+             prediction. It controls whether a `ProgressBarCallback` is added to the
+             callback list.
+         """
+         self.callbacks = [] if callbacks is None else callbacks
+
+         # check that user callbacks are not any of the CAREamics callbacks
+         for c in self.callbacks:
+             if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
+                 raise ValueError(
+                     "ModelCheckpoint and EarlyStopping callbacks are already defined "
+                     "in CAREamics and should only be modified through the "
+                     "training configuration (see TrainingConfig)."
+                 )
+
+             if isinstance(c, HyperParametersCallback) or isinstance(
+                 c, ProgressBarCallback
+             ):
+                 raise ValueError(
+                     "HyperParameter and ProgressBar callbacks are defined internally "
+                     "and should not be passed as callbacks."
+                 )
+
+         # checkpoint callback saves checkpoints during training
+         self.callbacks.extend(
+             [
+                 HyperParametersCallback(self.cfg),
+                 ModelCheckpoint(
+                     dirpath=self.work_dir / Path("checkpoints"),
+                     filename=f"{self.cfg.experiment_name}_{{epoch:02d}}_step_{{step}}",
+                     **self.cfg.training_config.checkpoint_callback.model_dump(),
+                 ),
+             ]
+         )
+         if enable_progress_bar:
+             self.callbacks.append(ProgressBarCallback())
+
+         # early stopping callback
+         if self.cfg.training_config.early_stopping_callback is not None:
+             self.callbacks.append(
+                 EarlyStopping(self.cfg.training_config.early_stopping_callback)
+             )
+
+     def stop_training(self) -> None:
+         """Stop the training loop."""
+         # raise stop training flag
+         self.trainer.should_stop = True
+         self.trainer.limit_val_batches = 0  # skip validation
+
+     # TODO: is there a more elegant way than calling train again after _train_on_path?
+     def train(
+         self,
+         *,
+         datamodule: TrainDataModule | None = None,
+         train_source: Union[Path, str, NDArray] | None = None,
+         val_source: Union[Path, str, NDArray] | None = None,
+         train_target: Union[Path, str, NDArray] | None = None,
+         val_target: Union[Path, str, NDArray] | None = None,
+         use_in_memory: bool = True,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 1,
+     ) -> None:
+         """
+         Train the model on the provided data.
+
+         If a datamodule is provided, then training will be performed using it.
+         Alternatively, the training data can be provided as arrays or paths.
+
+         If `use_in_memory` is set to True, the source provided as Path or str will be
+         loaded in memory if it fits. Otherwise, training will be performed by loading
+         patches from the files one by one. Training on arrays is always performed
+         in memory.
+
+         If no validation source is provided, then the validation is extracted from
+         the training data using `val_percentage` and `val_minimum_split`. In the case
+         of data provided as Path or str, the percentage and minimum number are applied
+         to the number of files. For arrays, they are applied to the number of patches.
+
+         Parameters
+         ----------
+         datamodule : TrainDataModule, optional
+             Datamodule to train on, by default None.
+         train_source : pathlib.Path or str or NDArray, optional
+             Train source, if no datamodule is provided, by default None.
+         val_source : pathlib.Path or str or NDArray, optional
+             Validation source, if no datamodule is provided, by default None.
+         train_target : pathlib.Path or str or NDArray, optional
+             Train target source, if no datamodule is provided, by default None.
+         val_target : pathlib.Path or str or NDArray, optional
+             Validation target source, if no datamodule is provided, by default None.
+         use_in_memory : bool, optional
+             Use in-memory dataset if possible, by default True.
+         val_percentage : float, optional
+             Percentage of validation extracted from training data, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of validation units (patches or files) extracted from the
+             training data, by default 1.
+
+         Raises
+         ------
+         ValueError
+             If both `datamodule` and `train_source` are provided.
+         ValueError
+             If sources are not of the same type (e.g. train is an array and val is
+             a Path).
+         ValueError
+             If a training target is provided to N2V.
+         ValueError
+             If neither a datamodule nor a source is provided.
+         """
+         if datamodule is not None and train_source is not None:
+             raise ValueError(
+                 "Only one of `datamodule` and `train_source` can be provided."
+             )
+
+         # check that inputs are of the same type
+         source_types = {
+             type(s)
+             for s in (train_source, val_source, train_target, val_target)
+             if s is not None
+         }
+         if len(source_types) > 1:
+             raise ValueError("All sources should be of the same type.")
+
+         # train
+         if datamodule is not None:
+             self._train_on_datamodule(datamodule=datamodule)
+
+         else:
+             # raise an error if a target is provided to N2V
+             if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:
+                 if train_target is not None:
+                     raise ValueError(
+                         "Training target not compatible with N2V training."
+                     )
+
+             # dispatch the training
+             if isinstance(train_source, np.ndarray):
+                 # mypy checks
+                 assert isinstance(val_source, np.ndarray) or val_source is None
+                 assert isinstance(train_target, np.ndarray) or train_target is None
+                 assert isinstance(val_target, np.ndarray) or val_target is None
+
+                 self._train_on_array(
+                     train_source,
+                     val_source,
+                     train_target,
+                     val_target,
+                     val_percentage,
+                     val_minimum_split,
+                 )
+
+             elif isinstance(train_source, Path) or isinstance(train_source, str):
+                 # mypy checks
+                 assert (
+                     isinstance(val_source, Path)
+                     or isinstance(val_source, str)
+                     or val_source is None
+                 )
+                 assert (
+                     isinstance(train_target, Path)
+                     or isinstance(train_target, str)
+                     or train_target is None
+                 )
+                 assert (
+                     isinstance(val_target, Path)
+                     or isinstance(val_target, str)
+                     or val_target is None
+                 )
+
+                 self._train_on_path(
+                     train_source,
+                     val_source,
+                     train_target,
+                     val_target,
+                     use_in_memory,
+                     val_percentage,
+                     val_minimum_split,
+                 )
+
+             else:
+                 raise ValueError(
+                     f"Invalid input, expected a str, Path, array or TrainDataModule "
+                     f"instance (got {type(train_source)})."
+                 )
+
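# Illustrative sketch of the two training entry points documented above; not
# part of the packaged file. `careamist` is the instance from the constructor
# sketch; the array and the paths are stand-ins.
import numpy as np

# 1) train on an array: 10% of patches (at least 5) are held out for validation
train_image = np.random.rand(512, 512).astype(np.float32)
careamist.train(train_source=train_image, val_percentage=0.1, val_minimum_split=5)

# 2) train from folders of files, streaming patches instead of loading all data
# careamist.train(
#     train_source="data/train",
#     val_source="data/val",
#     use_in_memory=False,
# )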
+     def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
+         """
+         Train the model on the provided datamodule.
+
+         Parameters
+         ----------
+         datamodule : TrainDataModule
+             Datamodule to train on.
+         """
+         # register datamodule
+         self.train_datamodule = datamodule
+
+         # set defaults (in case `stop_training` was called before)
+         self.trainer.should_stop = False
+         self.trainer.limit_val_batches = 1.0  # 100%
+
+         # train
+         self.trainer.fit(self.model, datamodule=datamodule)
+
+     def _train_on_array(
+         self,
+         train_data: NDArray,
+         val_data: NDArray | None = None,
+         train_target: NDArray | None = None,
+         val_target: NDArray | None = None,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 5,
+     ) -> None:
+         """
+         Train the model on the provided data arrays.
+
+         Parameters
+         ----------
+         train_data : NDArray
+             Training data.
+         val_data : NDArray, optional
+             Validation data, by default None.
+         train_target : NDArray, optional
+             Train target data, by default None.
+         val_target : NDArray, optional
+             Validation target data, by default None.
+         val_percentage : float, optional
+             Percentage of patches to use for validation, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of patches to use for validation, by default 5.
+         """
+         # create datamodule
+         datamodule = TrainDataModule(
+             data_config=self.cfg.data_config,
+             train_data=train_data,
+             val_data=val_data,
+             train_data_target=train_target,
+             val_data_target=val_target,
+             val_percentage=val_percentage,
+             val_minimum_split=val_minimum_split,
+         )
+
+         # train
+         self.train(datamodule=datamodule)
+
+     def _train_on_path(
+         self,
+         path_to_train_data: Union[Path, str],
+         path_to_val_data: Union[Path, str] | None = None,
+         path_to_train_target: Union[Path, str] | None = None,
+         path_to_val_target: Union[Path, str] | None = None,
+         use_in_memory: bool = True,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 1,
+     ) -> None:
+         """
+         Train the model on the provided data paths.
+
+         Parameters
+         ----------
+         path_to_train_data : pathlib.Path or str
+             Path to the training data.
+         path_to_val_data : pathlib.Path or str, optional
+             Path to validation data, by default None.
+         path_to_train_target : pathlib.Path or str, optional
+             Path to train target data, by default None.
+         path_to_val_target : pathlib.Path or str, optional
+             Path to validation target data, by default None.
+         use_in_memory : bool, optional
+             Use in-memory dataset if possible, by default True.
+         val_percentage : float, optional
+             Percentage of files to use for validation, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of files to use for validation, by default 1.
+         """
+         # sanity check on data (path exists)
+         path_to_train_data = check_path_exists(path_to_train_data)
+
+         if path_to_val_data is not None:
+             path_to_val_data = check_path_exists(path_to_val_data)
+
+         if path_to_train_target is not None:
+             path_to_train_target = check_path_exists(path_to_train_target)
+
+         if path_to_val_target is not None:
+             path_to_val_target = check_path_exists(path_to_val_target)
+
+         # create datamodule
+         datamodule = TrainDataModule(
+             data_config=self.cfg.data_config,
+             train_data=path_to_train_data,
+             val_data=path_to_val_data,
+             train_data_target=path_to_train_target,
+             val_data_target=path_to_val_target,
+             use_in_memory=use_in_memory,
+             val_percentage=val_percentage,
+             val_minimum_split=val_minimum_split,
+         )
+
+         # train
+         self.train(datamodule=datamodule)
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self, source: PredictDataModule
+     ) -> Union[list[NDArray], NDArray]: ...
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self,
+         source: Union[Path, str],
+         *,
+         batch_size: int = 1,
+         tile_size: tuple[int, ...] | None = None,
+         tile_overlap: tuple[int, ...] | None = (48, 48),
+         axes: str | None = None,
+         data_type: Literal["tiff", "custom"] | None = None,
+         tta_transforms: bool = False,
+         dataloader_params: dict | None = None,
+         read_source_func: Callable | None = None,
+         extension_filter: str = "",
+     ) -> Union[list[NDArray], NDArray]: ...
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self,
+         source: NDArray,
+         *,
+         batch_size: int = 1,
+         tile_size: tuple[int, ...] | None = None,
+         tile_overlap: tuple[int, ...] | None = (48, 48),
+         axes: str | None = None,
+         data_type: Literal["array"] | None = None,
+         tta_transforms: bool = False,
+         dataloader_params: dict | None = None,
+     ) -> Union[list[NDArray], NDArray]: ...
+
+     def predict(
+         self,
+         source: Union[PredictDataModule, Path, str, NDArray],
+         *,
+         batch_size: int = 1,
+         tile_size: tuple[int, ...] | None = None,
+         tile_overlap: tuple[int, ...] | None = (48, 48),
+         axes: str | None = None,
+         data_type: Literal["array", "tiff", "custom"] | None = None,
+         tta_transforms: bool = False,
+         dataloader_params: dict | None = None,
+         read_source_func: Callable | None = None,
+         extension_filter: str = "",
+         **kwargs: Any,
+     ) -> Union[list[NDArray], NDArray]:
+         """
+         Make predictions on the provided data.
+
+         Input can be a PredictDataModule instance, a path to a data file, or a numpy
+         array.
+
+         If `data_type`, `axes` and `tile_size` are not provided, the training
+         configuration parameters will be used, with the `patch_size` instead of
+         `tile_size`.
+
+         Test-time augmentation (TTA) can be switched on using the `tta_transforms`
+         parameter. The TTA augmentation applies all possible flips and 90-degree
+         rotations to the prediction input and averages the predictions. TTA
+         augmentation should not be used if you did not train with these
+         augmentations.
+
+         Note that if you are using a UNet model and tiling, the tile size must be
+         divisible in every dimension by 2**d, where d is the depth of the model. This
+         avoids artefacts arising from the broken shift invariance induced by the
+         pooling layers of the UNet. If your image is smaller along one axis (as may
+         happen along Z), consider padding it.
+
+         Parameters
+         ----------
+         source : PredictDataModule, pathlib.Path, str or numpy.ndarray
+             Data to predict on.
+         batch_size : int, default=1
+             Batch size for prediction.
+         tile_size : tuple of int, optional
+             Size of the tiles to use for prediction.
+         tile_overlap : tuple of int, default=(48, 48)
+             Overlap between tiles, can be None.
+         axes : str, optional
+             Axes of the input data, by default None.
+         data_type : {"array", "tiff", "custom"}, optional
+             Type of the input data.
+         tta_transforms : bool, default=False
+             Whether to apply test-time augmentation.
+         dataloader_params : dict, optional
+             Parameters to pass to the dataloader.
+         read_source_func : Callable, optional
+             Function to read the source data.
+         extension_filter : str, default=""
+             Filter for the file extension.
+         **kwargs : Any
+             Unused.
+
+         Returns
+         -------
+         list of NDArray or NDArray
+             Predictions made by the model.
+
+         Raises
+         ------
+         ValueError
+             If mean and std are not provided in the configuration.
+         ValueError
+             If the tile size is not divisible by 2**depth for UNet models.
+         ValueError
+             If the tile overlap is not specified.
+         """
+         if (
+             self.cfg.data_config.image_means is None
+             or self.cfg.data_config.image_stds is None
+         ):
+             raise ValueError("Mean and std must be provided in the configuration.")
+
+         # tile size for UNets
+         if tile_size is not None:
+             model = self.cfg.algorithm_config.model
+
+             if model.architecture == SupportedArchitecture.UNET.value:
+                 # the tile size must be equal to k*2^n, where n is the number of
+                 # pooling layers (equal to the depth) and k is an integer
+                 depth = model.depth
+                 tile_increment = 2**depth
+
+                 for i, t in enumerate(tile_size):
+                     if t % tile_increment != 0:
+                         raise ValueError(
+                             f"Tile size must be divisible by {tile_increment} along "
+                             f"all axes (got {t} for axis {i}). If your image size is "
+                             f"smaller along one axis (e.g. Z), consider padding the "
+                             f"image."
+                         )
+
+             # tile overlaps must be specified
+             if tile_overlap is None:
+                 raise ValueError("Tile overlap must be specified.")
+
+         # create the prediction datamodule
+         self.pred_datamodule = create_predict_datamodule(
+             pred_data=source,
+             data_type=data_type or self.cfg.data_config.data_type,  # type: ignore
+             axes=axes or self.cfg.data_config.axes,
+             image_means=self.cfg.data_config.image_means,
+             image_stds=self.cfg.data_config.image_stds,
+             tile_size=tile_size,
+             tile_overlap=tile_overlap,
+             batch_size=batch_size or self.cfg.data_config.batch_size,
+             tta_transforms=tta_transforms,
+             read_source_func=read_source_func,
+             extension_filter=extension_filter,
+             dataloader_params=dataloader_params,
+         )
+
+         # predict
+         predictions = self.trainer.predict(
+             model=self.model, datamodule=self.pred_datamodule
+         )
+         return convert_outputs(predictions, self.pred_datamodule.tiled)
+
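# Illustrative sketch of tiled prediction under the 2**depth rule explained in
# the docstring above; not part of the packaged file. Assuming the default UNet
# depth of 2, every tile dimension must be a multiple of 2**2 = 4.
import numpy as np

pred_image = np.random.rand(2048, 2048).astype(np.float32)  # stand-in image
prediction = careamist.predict(
    source=pred_image,
    tile_size=(256, 256),   # 256 % 4 == 0, satisfies the divisibility constraint
    tile_overlap=(48, 48),  # overlap lets the stitching hide tile seams
)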
+     def predict_to_disk(
+         self,
+         source: Union[PredictDataModule, Path, str],
+         *,
+         batch_size: int = 1,
+         tile_size: tuple[int, ...] | None = None,
+         tile_overlap: tuple[int, ...] | None = (48, 48),
+         axes: str | None = None,
+         data_type: Literal["tiff", "custom"] | None = None,
+         tta_transforms: bool = False,
+         dataloader_params: dict | None = None,
+         read_source_func: Callable | None = None,
+         extension_filter: str = "",
+         write_type: Literal["tiff", "custom"] = "tiff",
+         write_extension: str | None = None,
+         write_func: WriteFunc | None = None,
+         write_func_kwargs: dict[str, Any] | None = None,
+         prediction_dir: Union[Path, str] = "predictions",
+         **kwargs,
+     ) -> None:
+         """
+         Make predictions on the provided data and save the outputs to files.
+
+         The predictions will be saved in a new directory, by default 'predictions',
+         within the set working directory. The directory structure within the
+         prediction directory will match that of the source directory.
+
+         The `source` must point to files, not arrays. The file names of the
+         predictions will match those of the source. If there is more than one sample
+         within a file, the samples will be saved to separate files. The file names of
+         the samples will have the name of the corresponding source file with the
+         sample index appended, e.g. if the source file name is 'image.tiff', then the
+         first sample's prediction will be saved as 'image_0.tiff'.
+
+         Input can be a PredictDataModule instance or a path to a data file or
+         directory.
+
+         If `data_type`, `axes` and `tile_size` are not provided, the training
+         configuration parameters will be used, with the `patch_size` instead of
+         `tile_size`.
+
+         Test-time augmentation (TTA) can be switched on using the `tta_transforms`
+         parameter. The TTA augmentation applies all possible flips and 90-degree
+         rotations to the prediction input and averages the predictions. TTA
+         augmentation should not be used if you did not train with these
+         augmentations.
+
+         Note that if you are using a UNet model and tiling, the tile size must be
+         divisible in every dimension by 2**d, where d is the depth of the model. This
+         avoids artefacts arising from the broken shift invariance induced by the
+         pooling layers of the UNet. If your image is smaller along one axis (as may
+         happen along Z), consider padding it.
+
+         Parameters
+         ----------
+         source : PredictDataModule, pathlib.Path or str
+             Data to predict on.
+         batch_size : int, default=1
+             Batch size for prediction.
+         tile_size : tuple of int, optional
+             Size of the tiles to use for prediction.
+         tile_overlap : tuple of int, default=(48, 48)
+             Overlap between tiles.
+         axes : str, optional
+             Axes of the input data, by default None.
+         data_type : {"tiff", "custom"}, optional
+             Type of the input data.
+         tta_transforms : bool, default=False
+             Whether to apply test-time augmentation.
+         dataloader_params : dict, optional
+             Parameters to pass to the dataloader.
+         read_source_func : Callable, optional
+             Function to read the source data.
+         extension_filter : str, default=""
+             Filter for the file extension.
+         write_type : {"tiff", "custom"}, default="tiff"
+             The data type to save as, includes custom.
+         write_extension : str, optional
+             If a known `write_type` is selected this argument is ignored. For a custom
+             `write_type` an extension to save the data with must be passed.
+         write_func : WriteFunc, optional
+             If a known `write_type` is selected this argument is ignored. For a custom
+             `write_type` a function to save the data must be passed. See notes below.
+         write_func_kwargs : dict of {str: Any}, optional
+             Additional keyword arguments to be passed to the save function.
+         prediction_dir : pathlib.Path or str, default="predictions"
+             The path to save the prediction results to. If `prediction_dir` is not
+             absolute, the directory will be assumed to be relative to the pre-set
+             `work_dir`. If the directory does not exist it will be created.
+         **kwargs : Any
+             Unused.
+
+         Raises
+         ------
+         ValueError
+             If `write_type` is custom and `write_extension` is None.
+         ValueError
+             If `write_type` is custom and `write_func` is None.
+         ValueError
+             If `source` is not a `str`, `Path` or `PredictDataModule`.
+         """
+         if write_func_kwargs is None:
+             write_func_kwargs = {}
+
+         if Path(prediction_dir).is_absolute():
+             write_dir = Path(prediction_dir)
+         else:
+             write_dir = self.work_dir / prediction_dir
+         write_dir.mkdir(exist_ok=True, parents=True)
+
+         # guards for custom types
+         if write_type == SupportedData.CUSTOM:
+             if write_extension is None:
+                 raise ValueError(
+                     "A `write_extension` must be provided for custom write types."
+                 )
+             if write_func is None:
+                 raise ValueError(
+                     "A `write_func` must be provided for custom write types."
+                 )
+         else:
+             write_func = get_write_func(write_type)
+             write_extension = SupportedData.get_extension(write_type)
+
+         # extract file names
+         source_path: Union[Path, str, NDArray]
+         source_data_type: Literal["array", "tiff", "custom"]
+         if isinstance(source, PredictDataModule):
+             source_path = source.pred_data
+             source_data_type = source.data_type  # type: ignore
+             extension_filter = source.extension_filter
+         elif isinstance(source, (str, Path)):
+             source_path = source
+             source_data_type = (
+                 data_type or self.cfg.data_config.data_type  # type: ignore
+             )
+             extension_filter = SupportedData.get_extension_pattern(
+                 SupportedData(source_data_type)
+             )
+         else:
+             raise ValueError(f"Unsupported source type: '{type(source)}'.")
+
+         if source_data_type == "array":
+             raise ValueError(
+                 "Predicting to disk is not supported for input type 'array'."
+             )
+         assert isinstance(source_path, (Path, str))  # because data_type != "array"
+         source_path = Path(source_path)
+
+         file_paths = list_files(source_path, source_data_type, extension_filter)
+
+         # predict and write each file in turn
+         for file_path in file_paths:
+             # the write path is relative to the original source path and
+             # should mirror the original directory structure
+             prediction = self.predict(
+                 source=file_path,
+                 batch_size=batch_size,
+                 tile_size=tile_size,
+                 tile_overlap=tile_overlap,
+                 axes=axes,
+                 data_type=data_type,
+                 tta_transforms=tta_transforms,
+                 dataloader_params=dataloader_params,
+                 read_source_func=read_source_func,
+                 extension_filter=extension_filter,
+                 **kwargs,
+             )
+             # TODO: cast to float16?
+             write_data = np.concatenate(prediction)
+
+             # create the directory structure and the write path
+             if not source_path.is_file():
+                 file_write_dir = write_dir / file_path.parent.relative_to(source_path)
+             else:
+                 file_write_dir = write_dir
+             file_write_dir.mkdir(parents=True, exist_ok=True)
+             write_path = (file_write_dir / file_path.name).with_suffix(write_extension)
+
+             # write data
+             write_func(file_path=write_path, img=write_data)
+
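# Illustrative sketch of predicting to disk; not part of the packaged file.
# The paths are stand-ins; `save_npy` is a hypothetical custom writer whose
# signature mirrors the `write_func(file_path=..., img=...)` call above.
import numpy as np

# default: write TIFFs, mirroring the directory structure of "data/test"
careamist.predict_to_disk(source="data/test", tile_size=(256, 256))

# custom write type: an extension and a write function must both be supplied
def save_npy(file_path, img, **kwargs):
    np.save(file_path, img)

careamist.predict_to_disk(
    source="data/test",
    write_type="custom",
    write_extension=".npy",
    write_func=save_npy,
)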
+     def export_to_bmz(
+         self,
+         path_to_archive: Union[Path, str],
+         friendly_model_name: str,
+         input_array: NDArray,
+         authors: list[dict],
+         general_description: str,
+         data_description: str,
+         covers: list[Union[Path, str]] | None = None,
+         channel_names: list[str] | None = None,
+         model_version: str = "0.1.0",
+     ) -> None:
+         """Export the model to the BioImage Model Zoo format.
+
+         This method packages the current weights into a zip file that can be uploaded
+         to the BioImage Model Zoo. The archive consists of the model weights, the
+         model specifications and various files (inputs, outputs, README, env.yaml
+         etc.).
+
+         `path_to_archive` should point to a file with a ".zip" extension.
+
+         `friendly_model_name` is the name used for the model in the BMZ specs and
+         website; it should consist of letters, numbers, dashes, underscores and
+         parentheses only.
+
+         The input array must have the same dimensions as the axes recorded in the
+         configuration of the `CAREamist`.
+
+         Parameters
+         ----------
+         path_to_archive : pathlib.Path or str
+             Path in which to save the model, including the file name, which should
+             end with ".zip".
+         friendly_model_name : str
+             Name of the model as used in the BMZ specs, it should consist of letters,
+             numbers, dashes, underscores and parentheses only.
+         input_array : NDArray
+             Input array used to validate the model and as an example.
+         authors : list of dict
+             List of authors of the model.
+         general_description : str
+             General description of the model used in the BMZ metadata.
+         data_description : str
+             Description of the data the model was trained on.
+         covers : list of pathlib.Path or str, default=None
+             Paths to the cover images.
+         channel_names : list of str, default=None
+             Channel names.
+         model_version : str, default="0.1.0"
+             Version of the model.
+         """
+         # TODO: add in the docs that input_array dimensions are expected to match
+         # those in data_config
+
+         output_patch = self.predict(
+             input_array,
+             data_type=SupportedData.ARRAY.value,
+             tta_transforms=False,
+         )
+         output = np.concatenate(output_patch, axis=0)
+         input_array = reshape_array(input_array, self.cfg.data_config.axes)
+
+         export_to_bmz(
+             model=self.model,
+             config=self.cfg,
+             path_to_archive=path_to_archive,
+             model_name=friendly_model_name,
+             general_description=general_description,
+             data_description=data_description,
+             authors=authors,
+             input_array=input_array,
+             output_array=output,
+             covers=covers,
+             channel_names=channel_names,
+             model_version=model_version,
+         )
+
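# Illustrative sketch of a BMZ export; not part of the packaged file. The
# metadata values are placeholders; per the bioimage.io spec, each author is
# a dict with at least a "name" key.
import numpy as np

example_input = np.random.rand(256, 256).astype(np.float32)  # must match cfg axes
careamist.export_to_bmz(
    path_to_archive="exports/n2v_demo.zip",
    friendly_model_name="N2V-demo-model",
    input_array=example_input,
    authors=[{"name": "Jane Doe", "affiliation": "Example Lab"}],
    general_description="Noise2Void model trained on demo data.",
    data_description="Single-channel 2D fluorescence images.",
)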
+     def get_losses(self) -> dict[str, list]:
+         """Return data that can be used to plot train and validation loss curves.
+
+         Returns
+         -------
+         dict of {str: list}
+             Dictionary containing the losses for each epoch.
+         """
+         return read_csv_logger(self.cfg.experiment_name, self.work_dir / "csv_logs")
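# Illustrative sketch of plotting the returned losses; not part of the packaged
# file. The exact dictionary keys depend on `read_csv_logger`; "train_loss" and
# "val_loss" are assumptions.
import matplotlib.pyplot as plt

losses = careamist.get_losses()
plt.plot(losses["train_loss"], label="train")
plt.plot(losses["val_loss"], label="validation")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()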