careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,15 @@
1
+ """Pydantic models representing coordinate and patch filters."""
2
+
3
# Public API of the patch/coordinate filter configuration package; each name is
# imported from its own submodule below.
__all__ = [
    "FilterConfig",
    "MaskFilterConfig",
    "MaxFilterConfig",
    "MeanSTDFilterConfig",
    "ShannonFilterConfig",
]
10
+
11
+ from .filter_config import FilterConfig
12
+ from .mask_filter_config import MaskFilterConfig
13
+ from .max_filter_config import MaxFilterConfig
14
+ from .meanstd_filter_config import MeanSTDFilterConfig
15
+ from .shannon_filter_config import ShannonFilterConfig
@@ -0,0 +1,16 @@
1
+ """Base class for patch and coordinate filtering models."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class FilterConfig(BaseModel):
    """Common base model for patch and coordinate filter configurations.

    Concrete filter configurations subclass this model and narrow `name` to a
    literal identifying the filter.

    Attributes
    ----------
    name : str
        Identifier of the filter.
    p : float
        Probability, between 0 and 1 (inclusive), of applying the filter.
    seed : int or None
        Optional, strictly positive seed for the random number generator.
    """

    name: str
    """Identifier of the filter."""

    p: float = Field(default=1.0, ge=0.0, le=1.0)
    """Probability (between 0 and 1, inclusive) of applying the filter to a patch
    or coordinate."""

    seed: int | None = Field(None, gt=0)
    """Optional, strictly positive seed for the random number generator, used for
    reproducibility."""
@@ -0,0 +1,17 @@
1
+ """Pydantic model for the mask coordinate filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import Field
6
+
7
+ from .filter_config import FilterConfig
8
+
9
+
10
class MaskFilterConfig(FilterConfig):
    """Pydantic model for the mask coordinate filter.

    Attributes
    ----------
    name : "mask"
        Name of the filter.
    coverage : float
        Fraction (between 0 and 1, inclusive) of masked pixels required to keep a
        patch.
    """

    name: Literal["mask"] = "mask"
    """Name of the filter."""

    coverage: float = Field(0.5, ge=0.0, le=1.0)
    """Fraction of masked pixels, between 0 and 1 (inclusive), required to keep a
    patch."""
@@ -0,0 +1,15 @@
1
+ """Pydantic model for the max patch filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from .filter_config import FilterConfig
6
+
7
+
8
class MaxFilterConfig(FilterConfig):
    """Configuration of the max patch filter.

    Attributes
    ----------
    name : "max"
        Identifier of the filter, always "max".
    threshold : float
        Threshold on the minimum value of the max-filtered patch.
    """

    name: Literal["max"] = "max"
    """Identifier of the filter."""

    threshold: float
    """Threshold on the minimum value of the max-filtered patch."""
@@ -0,0 +1,18 @@
1
+ """Pydantic model for the mean std patch filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from .filter_config import FilterConfig
6
+
7
+
8
class MeanSTDFilterConfig(FilterConfig):
    """Configuration of the mean / standard-deviation patch filter.

    Attributes
    ----------
    name : "mean_std"
        Identifier of the filter, always "mean_std".
    mean_threshold : float
        Minimum mean intensity a patch must have to be kept.
    std_threshold : float or None
        Minimum standard deviation a patch must have to be kept; optional, defaults
        to `None`.
    """

    name: Literal["mean_std"] = "mean_std"
    """Identifier of the filter."""

    mean_threshold: float
    """Lowest mean intensity for which a patch is kept."""

    std_threshold: float | None = None
    """Lowest standard deviation for which a patch is kept (optional)."""
@@ -0,0 +1,15 @@
1
+ """Pydantic model for the Shannon entropy patch filter."""
2
+
3
+ from typing import Literal
4
+
5
+ from .filter_config import FilterConfig
6
+
7
+
8
class ShannonFilterConfig(FilterConfig):
    """Configuration of the Shannon entropy patch filter.

    Attributes
    ----------
    name : "shannon"
        Identifier of the filter, always "shannon".
    threshold : float
        Minimum Shannon entropy a patch must have to be kept.
    """

    name: Literal["shannon"] = "shannon"
    """Identifier of the filter."""

    threshold: float
    """Lowest Shannon entropy for which a patch is kept."""
@@ -0,0 +1,15 @@
1
+ """Patching strategies Pydantic models."""
2
+
3
# Public API of the patching-strategy configuration package; each name is
# imported from its own submodule below.
__all__ = [
    "FixedRandomPatchingConfig",
    "RandomPatchingConfig",
    "SequentialPatchingConfig",
    "TiledPatchingConfig",
    "WholePatchingConfig",
]
10
+
11
+
12
+ from .random_patching_config import FixedRandomPatchingConfig, RandomPatchingConfig
13
+ from .sequential_patching_config import SequentialPatchingConfig
14
+ from .tiled_patching_config import TiledPatchingConfig
15
+ from .whole_patching_config import WholePatchingConfig
@@ -0,0 +1,102 @@
1
"""Overlapping patching Pydantic model."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from pydantic import Field, ValidationInfo, field_validator
6
+
7
+ from ._patched_config import _PatchedConfig
8
+
9
+
10
class _OverlappingPatchedConfig(_PatchedConfig):
    """Overlapping patching Pydantic model.

    This model is only used for inheritance and validation purposes.

    Attributes
    ----------
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a
        power of 2 and at least 8.
    overlaps : sequence of int, optional
        The overlaps between patches in each spatial dimension. If `None`, no overlap
        is applied. The overlaps must be even, smaller than the patch size in each
        spatial dimension, and the number of dimensions be either 2 or 3.
    """

    overlaps: Sequence[int] | None = Field(
        default=None,
        min_length=2,
        max_length=3,
    )
    """The overlaps between patches in each spatial dimension. If `None`, no overlap is
    applied. The overlaps must be even, smaller than the patch size in each spatial
    dimension, and the number of dimensions be either 2 or 3.
    """

    @field_validator("overlaps")
    @classmethod
    def overlap_smaller_than_patch_size(
        cls, overlaps: Sequence[int] | None, values: ValidationInfo
    ) -> Sequence[int] | None:
        """
        Validate overlap.

        Overlaps must be smaller than the patch size in each spatial dimension.

        Parameters
        ----------
        overlaps : Sequence of int
            Overlap in each dimension.
        values : ValidationInfo
            Validation information, holding the already validated fields.

        Returns
        -------
        Sequence of int
            Validated overlap.

        Raises
        ------
        ValueError
            If overlaps and patch size do not have the same number of dimensions.
        ValueError
            If any overlap is larger than or equal to the corresponding patch size.
        """
        if overlaps is None:
            return None

        # `patch_size` is missing from `values.data` when its own validation failed.
        # That error is already reported, so there is nothing to compare against.
        # Indexing directly would raise a KeyError, which pydantic does not turn
        # into a ValidationError.
        patch_size = values.data.get("patch_size")
        if patch_size is None:
            return overlaps

        if len(overlaps) != len(patch_size):
            raise ValueError(
                f"Overlaps must have the same number of dimensions as the patch size. "
                f"Got {len(overlaps)} dimensions for overlaps and {len(patch_size)} "
                f"dimensions for patch size."
            )

        # lengths are equal at this point, so a strict zip cannot fail
        if any(o >= p for o, p in zip(overlaps, patch_size, strict=True)):
            raise ValueError(
                f"Overlap must be smaller than the patch size, got {overlaps} versus "
                f"{patch_size}."
            )

        return overlaps

    @field_validator("overlaps")
    @classmethod
    def overlap_even(cls, overlaps: Sequence[int] | None) -> Sequence[int] | None:
        """
        Validate overlaps.

        Overlap must be even.

        Parameters
        ----------
        overlaps : Sequence of int
            Overlaps.

        Returns
        -------
        Sequence of int
            Validated overlap.

        Raises
        ------
        ValueError
            If any overlap is odd.
        """
        if overlaps is None:
            return None

        if any(o % 2 != 0 for o in overlaps):
            raise ValueError(f"Overlaps must be even, got {overlaps}.")

        return overlaps
@@ -0,0 +1,56 @@
1
+ """Generic patching Pydantic model."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
6
+
7
+ from careamics.config.validators import patch_size_ge_than_8_power_of_2
8
+
9
+
10
class _PatchedConfig(BaseModel):
    """Generic patching Pydantic model.

    This model is only used for inheritance and validation purposes.

    Attributes
    ----------
    name : str
        The name of the patching strategy.
    patch_size : sequence of int
        The size of the patch in each spatial dimension (2 or 3 dimensions), each
        patch size must be a power of 2 and at least 8.
    """

    model_config = ConfigDict(
        extra="ignore",  # default behaviour, make it explicit
    )

    name: str
    """The name of the patching strategy."""

    patch_size: Sequence[int] = Field(..., min_length=2, max_length=3)
    """The size of the patch in each spatial dimension, each patch size must be a power
    of 2 and at least 8."""

    @field_validator("patch_size")
    @classmethod
    def all_elements_power_of_2_minimum_8(
        cls, patch_list: Sequence[int]
    ) -> Sequence[int]:
        """
        Validate patch size.

        Patch size must be powers of 2 and minimum 8.

        Parameters
        ----------
        patch_list : Sequence of int
            Patch size.

        Returns
        -------
        Sequence of int
            Validated patch size.

        Raises
        ------
        ValueError
            If the patch size is smaller than 8.
        ValueError
            If the patch size is not a power of 2.
        """
        # the shared validator raises on failure; the value itself is not modified
        patch_size_ge_than_8_power_of_2(patch_list)

        return patch_list
@@ -0,0 +1,45 @@
1
+ """Random patching Pydantic model."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import Field
6
+
7
+ from ._patched_config import _PatchedConfig
8
+
9
+
10
class RandomPatchingConfig(_PatchedConfig):
    """Random patching Pydantic model.

    Attributes
    ----------
    name : "random"
        The name of the patching strategy.
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    seed : int or None
        Strictly positive random seed for patch sampling, set to `None` for random
        seeding.
    """

    name: Literal["random"] = "random"
    """The name of the patching strategy."""

    seed: int | None = Field(default=None, gt=0)
    """Random seed for patch sampling (strictly positive), set to `None` for random
    seeding."""
27
+
28
+
29
class FixedRandomPatchingConfig(_PatchedConfig):
    """Fixed random patching Pydantic model.

    Attributes
    ----------
    name : "fixed_random"
        The name of the patching strategy.
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    seed : int or None
        Optional, strictly positive random seed to use for patch sampling.
    """

    name: Literal["fixed_random"] = "fixed_random"
    """The name of the patching strategy."""

    seed: int | None = Field(default=None, gt=0)
    """The random seed to use for patch sampling (optional, strictly positive)."""
@@ -0,0 +1,25 @@
1
+ """Sequential patching Pydantic model."""
2
+
3
+ from typing import Literal
4
+
5
+ from ._overlapping_patched_config import _OverlappingPatchedConfig
6
+
7
+
8
class SequentialPatchingConfig(_OverlappingPatchedConfig):
    """Sequential patching Pydantic model.

    Attributes
    ----------
    name : "sequential"
        The name of the patching strategy.
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    overlaps : sequence of int, optional
        The overlaps between patches in each spatial dimension. If `None`, no overlap is
        applied. The overlaps must be even, smaller than the patch size in each spatial
        dimension, and the number of dimensions be either 2 or 3.
    """

    name: Literal["sequential"] = "sequential"
    """The name of the patching strategy."""
@@ -0,0 +1,40 @@
1
+ """Tiled patching Pydantic model."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Literal
5
+
6
+ from pydantic import Field
7
+
8
+ from ._overlapping_patched_config import _OverlappingPatchedConfig
9
+
10
+
11
# TODO: with UNet, tiling must obey different rules than sequential tiling; this
# needs to be validated at the level of the configuration.
class TiledPatchingConfig(_OverlappingPatchedConfig):
    """Tiled patching Pydantic model.

    Attributes
    ----------
    name : "tiled"
        The name of the patching strategy.
    patch_size : sequence of int
        The size of the patch in each spatial dimension, each patch size must be a power
        of 2 and at least 8.
    overlaps : sequence of int
        The overlaps between patches in each spatial dimension (required). The overlaps
        must be even, smaller than the patch size in each spatial dimension, and the
        number of dimensions be either 2 or 3.
    """

    name: Literal["tiled"] = "tiled"
    """The name of the patching strategy."""

    # Redeclared without a default so that overlaps are mandatory for tiling; the
    # validators inherited from the parent class still apply to this field.
    overlaps: Sequence[int] = Field(
        ...,
        min_length=2,
        max_length=3,
    )
    """The overlaps between patches in each spatial dimension. The overlaps must be
    even, smaller than the patch size in each spatial dimension, and the number of
    dimensions be either 2 or 3.
    """
@@ -0,0 +1,12 @@
1
+ """Whole image patching Pydantic model."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class WholePatchingConfig(BaseModel):
    """Configuration for using whole images instead of patches.

    Attributes
    ----------
    name : "whole"
        Identifier of the patching strategy, always "whole".
    """

    name: Literal["whole"] = "whole"
    """Identifier of the patching strategy."""
@@ -0,0 +1,65 @@
1
+ """Pydantic model representing the metadata of a prediction tile."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Annotated
6
+
7
+ from annotated_types import Len
8
+ from pydantic import BaseModel, ConfigDict
9
+
10
# Tuple constrained to 3 or 4 elements, used for C(Z)YX array shapes.
DimTuple = Annotated[tuple, Len(min_length=3, max_length=4)]
11
+
12
+
13
class TileInformation(BaseModel):
    """
    Pydantic model containing tile information.

    This model is used to represent the information required to stitch back a tile into
    a larger image. It is used throughout the prediction pipeline of CAREamics.

    Array shape should be C(Z)YX, where Z is an optional dimension.
    """

    model_config = ConfigDict(validate_default=True)

    array_shape: DimTuple  # TODO: find a way to add custom error message?
    """Shape of the original (untiled) array, with 3 or 4 dimensions."""

    last_tile: bool = False
    """Whether this tile is the last one of the array."""

    overlap_crop_coords: tuple[tuple[int, ...], ...]
    """Inner coordinates of the tile where to crop the prediction in order to stitch
    it back into the original image."""

    stitch_coords: tuple[tuple[int, ...], ...]
    """Coordinates in the original image where to stitch the cropped tile back."""

    sample_id: int
    """Sample ID of the tile."""

    # TODO: Test that ZYX axes are not singleton ?

    def __eq__(self, other_tile: object) -> bool:
        """Check if two tile information objects are equal.

        Two objects are equal when all their fields (`array_shape`, `last_tile`,
        `overlap_crop_coords`, `stitch_coords` and `sample_id`) are equal.

        Parameters
        ----------
        other_tile : object
            Tile information object to compare with.

        Returns
        -------
        bool
            Whether the two tile information objects are equal. `NotImplemented` is
            returned when `other_tile` is not a `TileInformation`, so that Python
            falls back to its default comparison.
        """
        if not isinstance(other_tile, TileInformation):
            return NotImplemented

        return (
            self.array_shape == other_tile.array_shape
            and self.last_tile == other_tile.last_tile
            and self.overlap_crop_coords == other_tile.overlap_crop_coords
            and self.stitch_coords == other_tile.stitch_coords
            and self.sample_id == other_tile.sample_id
        )
@@ -0,0 +1,15 @@
1
+ """Training and lightning related Pydantic configurations."""
2
+
3
# Public API of the package. Every name listed here must be imported in this
# module, otherwise `from careamics.config.lightning import *` raises
# AttributeError ("TrainerConfig" was listed but never imported).
__all__ = [
    "CheckpointConfig",
    "EarlyStoppingConfig",
    "LrSchedulerConfig",
    "OptimizerConfig",
    "TrainingConfig",
]
11
+
12
+
13
+ from .callbacks import CheckpointConfig, EarlyStoppingConfig
14
+ from .optimizer_configs import LrSchedulerConfig, OptimizerConfig
15
+ from .training_config import TrainingConfig
@@ -0,0 +1,8 @@
1
+ """Callbacks Pydantic configurations."""
2
+
3
# Public API of the callbacks configuration package; both names are imported
# from `callback_config` below.
__all__ = [
    "CheckpointConfig",
    "EarlyStoppingConfig",
]
7
+
8
+ from .callback_config import CheckpointConfig, EarlyStoppingConfig
@@ -0,0 +1,116 @@
1
+ """Callback Pydantic models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from datetime import timedelta
6
+ from typing import Literal
7
+
8
+ from pydantic import (
9
+ BaseModel,
10
+ ConfigDict,
11
+ Field,
12
+ )
13
+
14
+
15
class CheckpointConfig(BaseModel):
    """Checkpoint saving callback Pydantic model.

    The parameters correspond to those of
    `pytorch_lightning.callbacks.ModelCheckpoint`.

    See:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    """

    model_config = ConfigDict(validate_assignment=True, validate_default=True)

    monitor: Literal["val_loss"] | str | None = Field(default="val_loss")
    """Quantity to monitor, `val_loss` by default. Any other metric name can be
    given, and `None` means that no quantity is monitored."""

    verbose: bool = Field(default=False)
    """Verbosity mode."""

    save_weights_only: bool = Field(default=False)
    """When `True`, only the model's weights are saved. Otherwise, the optimizer
    and learning rate scheduler states are added to the checkpoint as well."""

    save_last: Literal[True, False, "link"] | None = Field(default=True)
    """When `True`, saves a `last.ckpt` copy whenever a checkpoint file gets saved.
    When `"link"`, a symbolic link to the latest checkpoint is created instead of a
    copy (local filesystems only)."""

    save_top_k: int = Field(
        default=3,
        ge=-1,
        le=100,
    )
    """If `save_top_k == k`, the best k models according to the quantity monitored
    will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
    all models are saved."""

    mode: Literal["min", "max"] = Field(default="min")
    """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
    save file is made based on either the maximization or the minimization of the
    monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
    be 'min', etc.
    """

    auto_insert_metric_name: bool = Field(default=False)
    """When `True`, the checkpoints filenames will contain the metric name."""

    # `every_n_train_steps`, `train_time_interval` and `every_n_epochs` are mutually
    # exclusive in Lightning's `ModelCheckpoint`: set at most one of them.
    every_n_train_steps: int | None = Field(default=None, ge=1, le=1000)
    """Number of training steps between checkpoints."""

    train_time_interval: timedelta | None = Field(default=None)
    """Checkpoints are monitored at the specified time interval."""

    every_n_epochs: int | None = Field(default=None, ge=1, le=100)
    """Number of epochs between checkpoints."""
66
+
67
+
68
class EarlyStoppingConfig(BaseModel):
    """Early stopping callback Pydantic model.

    The parameters correspond to those of
    `pytorch_lightning.callbacks.EarlyStopping`.

    See:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
    """

    model_config = ConfigDict(
        validate_assignment=True,
        validate_default=True,
    )

    monitor: Literal["val_loss"] = Field(default="val_loss")
    """Quantity to monitor."""

    min_delta: float = Field(default=0.0, ge=0.0, le=1.0)
    """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
    absolute change of less than or equal to min_delta, will count as no improvement."""

    patience: int = Field(default=3, ge=1, le=10)
    """Number of checks with no improvement after which training will be stopped."""

    verbose: bool = Field(default=False)
    """Verbosity mode."""

    # Lightning's `EarlyStopping` only accepts "min" and "max" and raises a
    # `MisconfigurationException` for any other value (including "auto"), so the
    # accepted values are restricted here to fail at validation time instead of
    # when the callback is instantiated.
    mode: Literal["min", "max"] = Field(default="min")
    """One of {min, max}. In `min` mode, training stops when the monitored quantity
    has stopped decreasing; in `max` mode, when it has stopped increasing."""

    check_finite: bool = Field(default=True)
    """When `True`, stops training when the monitored quantity becomes `NaN` or
    `inf`."""

    stopping_threshold: float | None = Field(default=None)
    """Stop training immediately once the monitored quantity reaches this threshold."""

    divergence_threshold: float | None = Field(default=None)
    """Stop training as soon as the monitored quantity becomes worse than this
    threshold."""

    check_on_train_epoch_end: bool | None = Field(default=False)
    """Whether to run early stopping at the end of the training epoch. If this is
    `False`, then the check runs at the end of the validation."""

    log_rank_zero_only: bool = Field(default=False)
    """When set `True`, logs the status of the early stopping callback only for rank 0
    process."""