careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,59 @@
1
+ """Pydantic model for the Normalize transform."""
2
+
3
+ from typing import Literal, Self
4
+
5
+ from pydantic import ConfigDict, Field, model_validator
6
+
7
+ from .transform_config import TransformConfig
8
+
9
+
10
class NormalizeConfig(TransformConfig):
    """
    Pydantic model used to represent Normalize transformation.

    The Normalize transform is a zero mean and unit variance transformation.

    Attributes
    ----------
    name : Literal["Normalize"]
        Name of the transformation.
    image_means : list
        Mean values used to normalize the input images (at most 32 values).
    image_stds : list
        Standard deviation values used to normalize the input images (at most 32
        values). Must have the same length as `image_means`.
    target_means : list or None
        Mean values used to normalize the targets, by default None.
    target_stds : list or None
        Standard deviation values used to normalize the targets, by default None.
        Must be provided if and only if `target_means` is, and with the same length.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["Normalize"] = "Normalize"
    image_means: list = Field(..., min_length=0, max_length=32)
    image_stds: list = Field(..., min_length=0, max_length=32)
    target_means: list | None = Field(default=None, min_length=0, max_length=32)
    target_stds: list | None = Field(default=None, min_length=0, max_length=32)

    @model_validator(mode="after")
    def validate_means_stds(self: Self) -> Self:
        """Validate that the means and stds have the same length.

        Returns
        -------
        Self
            The instance of the model.

        Raises
        ------
        ValueError
            If the image means and stds have different lengths, if only one of the
            target means and stds is provided, or if the target means and stds
            have different lengths.
        """
        if len(self.image_means) != len(self.image_stds):
            raise ValueError("The number of image means and stds must be the same.")

        # target statistics are optional, but must come as a pair
        if (self.target_means is None) != (self.target_stds is None):
            raise ValueError(
                "Both target means and stds must be provided together, or both None."
            )

        if self.target_means is not None and self.target_stds is not None:
            if len(self.target_means) != len(self.target_stds):
                raise ValueError(
                    "The number of target means and stds must be the same."
                )

        return self
@@ -0,0 +1,45 @@
1
+ """Parent model for the transforms."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+
8
class TransformConfig(BaseModel):
    """
    Pydantic model used to represent a transformation.

    The `model_dump` method is overwritten to exclude the name field.

    Attributes
    ----------
    name : str
        Name of the transformation.
    """

    model_config = ConfigDict(
        extra="forbid",  # throw errors if the parameters are not properly passed
    )

    name: str

    def model_dump(self, **kwargs) -> dict[str, Any]:
        """
        Return the model as a dictionary, without the `name` field.

        Parameters
        ----------
        **kwargs
            Pydantic BaseModel model_dump method keyword arguments.

        Returns
        -------
        {str: Any}
            Dictionary representation of the model.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the name field; it may already be missing when the caller passed
        # `include`, `exclude` or `exclude_defaults` (subclasses give `name` a
        # default value), in which case there is nothing to remove
        model_dict.pop("name", None)

        return model_dict
@@ -0,0 +1,29 @@
1
+ """Type used to represent all transformations users can create."""
2
+
3
+ from typing import Annotated, Union
4
+
5
+ from pydantic import Discriminator
6
+
7
+ from .normalize_config import NormalizeConfig
8
+ from .xy_flip_config import XYFlipConfig
9
+ from .xy_random_rotate90_config import XYRandomRotate90Config
10
+
11
# The pydantic discriminator reads the `name` field of each config to pick the
# matching transform model during validation.
NORM_AND_SPATIAL_UNION = Annotated[
    Union[NormalizeConfig, XYFlipConfig, XYRandomRotate90Config],
    Discriminator("name"),  # used to tell the different transform models apart
]
"""All transforms including normalization."""


SPATIAL_TRANSFORMS_UNION = Annotated[
    Union[XYFlipConfig, XYRandomRotate90Config],
    Discriminator("name"),  # used to tell the different transform models apart
]
"""Available spatial transforms in CAREamics."""
@@ -0,0 +1,43 @@
1
+ """Pydantic model for the XYFlip transform."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import ConfigDict, Field
6
+
7
+ from .transform_config import TransformConfig
8
+
9
+
10
class XYFlipConfig(TransformConfig):
    """
    Pydantic model used to represent XYFlip transformation.

    Attributes
    ----------
    name : Literal["XYFlip"]
        Name of the transformation.
    flip_x : bool
        Whether to flip along the X axis, by default True.
    flip_y : bool
        Whether to flip along the Y axis, by default True.
    p : float
        Probability of applying the transform, between 0 and 1, by default 0.5.
    seed : int or None
        Seed for the random number generator, by default None.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["XYFlip"] = "XYFlip"
    flip_x: bool = Field(
        True,
        description="Whether to flip along the X axis.",
    )
    flip_y: bool = Field(
        True,
        description="Whether to flip along the Y axis.",
    )
    p: float = Field(
        0.5,
        description="Probability of applying the transform.",
        ge=0,
        le=1,
    )
    seed: int | None = None
@@ -0,0 +1,35 @@
1
+ """Pydantic model for the XYRandomRotate90 transform."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import ConfigDict, Field
6
+
7
+ from .transform_config import TransformConfig
8
+
9
+
10
class XYRandomRotate90Config(TransformConfig):
    """
    Configuration of the random 90 degree rotation in the XY plane.

    Attributes
    ----------
    name : Literal["XYRandomRotate90"]
        Name of the transformation, used to tell transform models apart.
    p : float
        Probability, between 0 and 1, of applying the rotation. Defaults to 0.5.
    seed : int or None
        Optional seed of the random number generator. Defaults to None.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
    p: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability of applying the transform.",
    )
    seed: int | None = None
@@ -0,0 +1,8 @@
1
+ """Configuration utilities."""
2
+
3
# Public API of this package: configuration (de)serialization helpers.
__all__ = ["load_configuration", "save_configuration"]
7
+
8
+ from .configuration_io import load_configuration, save_configuration
@@ -0,0 +1,85 @@
1
+ """I/O functions for Configuration objects."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import yaml
7
+
8
+ from careamics.config import Configuration
9
+
10
+
11
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    ValueError
        If the file does not contain a YAML mapping (e.g. it is empty).
    """
    config_path = Path(path)

    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in {Path.cwd()!s}"
        )

    # load dictionary from yaml; the context manager closes the file handle
    with config_path.open("r", encoding="utf-8") as f:
        dictionary = yaml.safe_load(f)

    # an empty file yields None, and a YAML list/scalar cannot be unpacked as kwargs
    if not isinstance(dictionary, dict):
        raise ValueError(
            f"Configuration file {path} does not contain a valid configuration "
            f"(expected a mapping, got {type(dictionary).__name__})."
        )

    return Configuration(**dictionary)
39
+
40
+
41
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Path to an existing folder in which to save the configuration (as
        `config.yml`), or to a valid configuration file path (uses a .yml or
        .yaml extension).

    Returns
    -------
    Path
        Path to the saved configuration file.

    Raises
    ------
    ValueError
        If the path is neither an existing directory nor a path with a .yml or
        .yaml extension.
    """
    # make sure path is a Path object
    config_path = Path(path)

    # an existing directory receives a default file name; any other path (existing
    # or not) must itself be a yaml file path. `is_dir` is False for missing paths.
    if config_path.is_dir():
        config_path = config_path / "config.yml"
    elif config_path.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {config_path})."
        )

    # save configuration as dictionary to yaml
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)

    return config_path
@@ -0,0 +1,18 @@
1
+ """Validator utilities."""
2
+
3
__all__ = [
    # axes_validators
    "check_axes_validity",
    "check_czi_axes_validity",
    # model_validators
    "model_matching_in_out_channels",
    "model_without_final_activation",
    "model_without_n2v2",
    # patch_validators
    "patch_size_ge_than_8_power_of_2",
]
11
+
12
+ from .axes_validators import check_axes_validity, check_czi_axes_validity
13
+ from .model_validators import (
14
+ model_matching_in_out_channels,
15
+ model_without_final_activation,
16
+ model_without_n2v2,
17
+ )
18
+ from .patch_validators import patch_size_ge_than_8_power_of_2
@@ -0,0 +1,90 @@
1
+ """Axes validation utilities."""
2
+
3
_AXES = "STCZYX"


def check_axes_validity(axes: str) -> None:
    """
    Sanity check on axes.

    The constraints on the axes are the following:
    - must be a combination of 'STCZYX'
    - must not contain duplicates
    - must contain at least 2 contiguous axes: X and Y
    - must contain at most 6 axes

    Axes do not need to be in the order 'STCZYX', as this depends on the user data.
    The check is case-insensitive.

    Parameters
    ----------
    axes : str
        Axes to validate.

    Raises
    ------
    ValueError
        If the axes do not satisfy the constraints above.
    """
    _axes = axes.upper()

    # Minimum is 2 (YX) and maximum is 6 (STCZYX, as duplicates are not allowed)
    if len(_axes) < 2 or len(_axes) > 6:
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
        )

    if "YX" not in _axes and "XY" not in _axes:
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
        )

    # all characters must be in _AXES = 'STCZYX'
    if not all(s in _AXES for s in _axes):
        raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")

    # check for repeating characters: a repeated axis has a later occurrence
    for i, s in enumerate(_axes):
        if i != _axes.rfind(s):
            raise ValueError(
                f"Invalid axes {axes}. Cannot contain duplicate axes"
                f" (got multiple {axes[i]})."
            )
47
+
48
+
49
def check_czi_axes_validity(axes: str) -> bool:
    """
    Check if the provided axes string is valid for CZI files.

    CZI axes is always in the "SC(Z/T)YX" format, where Z or T are optional, and S and C
    can be singleton dimensions, but must be provided. Each axis may appear only once.

    Parameters
    ----------
    axes : str
        The axes string to validate.

    Returns
    -------
    bool
        True if the axes string is valid, False otherwise.
    """
    valid_axes = {"S", "C", "Z", "T", "Y", "X"}
    axes_set = set(axes)

    # check for duplicated axes (e.g. "SCYXX"); the set-based checks below ignore
    # repeats and the order check only requires a non-decreasing order
    if len(axes_set) != len(axes):
        return False

    # check for invalid characters
    if not axes_set.issubset(valid_axes):
        return False

    # check for mandatory axes
    if not {"S", "C", "Y", "X"}.issubset(axes_set):
        return False

    # check for mutually exclusive axes
    if "Z" in axes_set and "T" in axes_set:
        return False

    # check for correct order
    order = "SCZYX" if "Z" in axes else "SCTYX"
    last_index = -1
    for axis in axes:
        current_index = order.find(axis)
        if current_index < last_index:
            return False
        last_index = current_index

    return True
@@ -0,0 +1,84 @@
1
+ """Architecture model validators."""
2
+
3
+ from careamics.config.architectures import UNetConfig
4
+
5
+
6
def model_without_n2v2(model: UNetConfig) -> UNetConfig:
    """Ensure that the UNet configuration does not enable N2V2.

    Parameters
    ----------
    model : UNetConfig
        UNet configuration to check.

    Returns
    -------
    UNetConfig
        The same configuration, unchanged.

    Raises
    ------
    ValueError
        If `model.n2v2` is truthy.
    """
    if not model.n2v2:
        return model

    raise ValueError(
        "The algorithm does not support the `n2v2` attribute in the model. "
        "Set it to `False`."
    )
31
+
32
+
33
def model_without_final_activation(model: UNetConfig) -> UNetConfig:
    """Ensure that the UNet configuration has no final activation.

    The absence of a final activation is encoded as the string ``"None"``.

    Parameters
    ----------
    model : UNetConfig
        UNet configuration to check.

    Returns
    -------
    UNetConfig
        The same configuration, unchanged.

    Raises
    ------
    ValueError
        If `model.final_activation` is anything other than ``"None"``.
    """
    if model.final_activation == "None":
        return model

    raise ValueError(
        "The algorithm does not support a `final_activation` in the model. "
        'Set it to `"None"`.'
    )
58
+
59
+
60
def model_matching_in_out_channels(model: UNetConfig) -> UNetConfig:
    """Ensure that the UNet configuration outputs as many channels as it receives.

    Parameters
    ----------
    model : UNetConfig
        UNet configuration to check.

    Returns
    -------
    UNetConfig
        The same configuration, unchanged.

    Raises
    ------
    ValueError
        If `model.in_channels` and `model.num_classes` differ.
    """
    channels_match = model.num_classes == model.in_channels
    if not channels_match:
        raise ValueError(
            "The algorithm requires the same number of input and output channels. "
            "Make sure that `in_channels` and `num_classes` are equal."
        )

    return model
@@ -0,0 +1,55 @@
1
+ """
2
+ Validator functions.
3
+
4
+ These functions are used to validate dimensions and axes of inputs.
5
+ """
6
+
7
+ from collections.abc import Sequence
8
+
9
+
10
def _value_ge_than_8_power_of_2(
    value: int,
) -> None:
    """
    Validate that the value is greater or equal than 8 and a power of 2.

    Parameters
    ----------
    value : int
        Value to validate.

    Raises
    ------
    ValueError
        If the value is smaller than 8.
    ValueError
        If the value is not a power of 2.
    """
    if value < 8:
        raise ValueError(f"Value must be greater than or equal to 8 (got {value}).")

    # a power of two has a single bit set, so clearing its lowest set bit gives 0
    if (value & (value - 1)) != 0:
        raise ValueError(f"Value must be a power of 2 (got {value}).")
33
+
34
+
35
def patch_size_ge_than_8_power_of_2(
    patch_list: Sequence[int] | None,
) -> None:
    """
    Validate that each entry is greater or equal than 8 and a power of 2.

    Parameters
    ----------
    patch_list : Sequence of int, or None
        Patch size. If None, no validation is performed.

    Raises
    ------
    ValueError
        If any entry of the patch size is smaller than 8.
    ValueError
        If any entry of the patch size is not a power of 2.
    """
    if patch_list is not None:
        for dim in patch_list:
            _value_ge_than_8_power_of_2(dim)
careamics/conftest.py ADDED
@@ -0,0 +1,39 @@
1
+ """File used to discover python modules and run doctest.
2
+
3
+ See https://sybil.readthedocs.io/en/latest/use.html#pytest
4
+ """
5
+
6
+ from pathlib import Path
7
+
8
+ import pytest
9
+ from pytest import TempPathFactory
10
+ from sybil import Sybil
11
+ from sybil.parsers.codeblock import PythonCodeBlockParser
12
+ from sybil.parsers.doctest import DocTestParser
13
+
14
+
15
@pytest.fixture(scope="module")
def my_path(tmp_path_factory: TempPathFactory) -> Path:
    """Fixture used in doctest to create a temporary directory.

    Parameters
    ----------
    tmp_path_factory : TempPathFactory
        Temporary path factory from pytest.

    Returns
    -------
    Path
        Temporary directory path.
    """
    # `tmp_path_factory` returns a `pathlib.Path`, matching the annotations above,
    # whereas the legacy `tmpdir_factory` returns a `py.path.local`
    return tmp_path_factory.mktemp("my_path")
30
+
31
+
32
# Sybil collects the examples embedded in the `*.py` files and runs them both as
# doctests and as python code blocks; the examples may request the `my_path`
# fixture defined above.
_sybil_parsers = [
    DocTestParser(),
    PythonCodeBlockParser(future_imports=["print_function"]),
]
pytest_collect_file = Sybil(
    parsers=_sybil_parsers,
    fixtures=["my_path"],
    pattern="*.py",
).pytest()
@@ -0,0 +1,17 @@
1
+ """Dataset module."""
2
+
3
__all__ = [
    # in-memory datasets
    "InMemoryDataset",
    "InMemoryPredDataset",
    "InMemoryTiledPredDataset",
    # iterable datasets
    "IterablePredDataset",
    "IterableTiledPredDataset",
    "PathIterableDataset",
]
11
+
12
+ from .in_memory_dataset import InMemoryDataset
13
+ from .in_memory_pred_dataset import InMemoryPredDataset
14
+ from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
15
+ from .iterable_dataset import PathIterableDataset
16
+ from .iterable_pred_dataset import IterablePredDataset
17
+ from .iterable_tiled_pred_dataset import IterableTiledPredDataset
@@ -0,0 +1,19 @@
1
+ """Files and arrays utils used in the datasets."""
2
+
3
# Public helpers re-exported from the submodules imported below.
__all__ = [
    "WelfordStatistics",  # running_stats
    "compute_normalization_stats",  # running_stats
    "get_files_size",  # file_utils
    "iterate_over_files",  # iterate_over_files
    "list_files",  # file_utils
    "reshape_array",  # dataset_utils
    "validate_source_target_files",  # file_utils
]
12
+
13
+
14
+ from .dataset_utils import (
15
+ reshape_array,
16
+ )
17
+ from .file_utils import get_files_size, list_files, validate_source_target_files
18
+ from .iterate_over_files import iterate_over_files
19
+ from .running_stats import WelfordStatistics, compute_normalization_stats