careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,5 @@
1
+ """
2
+ Package containing functions called by the careamics cli.
3
+
4
+ Built using third party package Typer.
5
+ """
careamics/cli/conf.py ADDED
@@ -0,0 +1,394 @@
1
+ """Configuration building convenience functions for the CAREamics CLI."""
2
+
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Annotated
7
+
8
+ import click
9
+ import typer
10
+ import yaml
11
+
12
+ from ..config import (
13
+ Configuration,
14
+ create_care_configuration,
15
+ create_n2n_configuration,
16
+ create_n2v_configuration,
17
+ save_configuration,
18
+ )
19
+ from .utils import handle_2D_3D_callback
20
+
21
# Default value of the `--dir` option; evaluated once, when this module is imported.
WORK_DIR = Path.cwd()

# Typer sub-application holding the configuration builders; it is mounted as the
# `conf` subcommand group by the main CLI module.
app = typer.Typer()
24
+
25
+
26
def _config_builder_exit(ctx: typer.Context, config: Configuration) -> None:
    """
    Finish a CLI configuration builder command.

    Writes `config` to `<dir>/<name>.yaml`, using the options stored on the Typer
    context by the `conf` group callback. If requested, it also echoes the
    configuration to the console as YAML.

    Parameters
    ----------
    ctx : typer.Context
        Typer Context whose `obj` holds the `conf` command options.
    config : Configuration
        CAREamics configuration to save.
    """
    options = ctx.obj
    # `with_suffix` guarantees the file ends in ".yaml" whatever name was given.
    yaml_path = (options.dir / options.name).with_suffix(".yaml")
    save_configuration(config, yaml_path)

    if not options.print:
        return

    dumped = yaml.dump(config.model_dump(), indent=2)
    print(dumped)
44
+
45
+
46
@dataclass
class ConfOptions:
    """
    Data class for containing CLI `conf` command option values.

    Attributes
    ----------
    dir : Path
        Directory in which the configuration file is written.
    name : str
        Configuration file name; the ".yaml" suffix is applied when saving.
    force : bool
        Whether an existing configuration file may be overwritten.
    print : bool
        Whether the configuration is also echoed to the console as YAML.
    """

    # Field order matters: instances are built positionally by the group callback.
    dir: Path  # output directory (`--dir`)
    name: str  # file name (`--name`)
    force: bool  # overwrite flag (`--force`)
    print: bool  # echo flag (`--print`)
54
+
55
+
56
@app.callback()
def conf_options(  # numpydoc ignore=PR01
    ctx: typer.Context,
    dir: Annotated[
        Path,
        typer.Option(
            "--dir",
            "-d",
            exists=True,
            # Reject existing *files*: the config is written to `dir / name`, which
            # would otherwise fail later with an obscure NotADirectoryError.
            file_okay=False,
            dir_okay=True,
            help="Directory to save the config file to.",
        ),
    ] = WORK_DIR,
    name: Annotated[
        str, typer.Option("--name", "-n", help="The config file name.")
    ] = "config",
    force: Annotated[
        bool,
        typer.Option(
            "--force", "-f", help="Whether to overwrite existing config files."
        ),
    ] = False,
    print: Annotated[
        bool,
        typer.Option(
            "--print",
            "-p",
            help="Whether to print the config file to the console.",
        ),
    ] = False,
):
    """Build and save CAREamics configuration files."""
    # The group callback still runs when help is requested for a subcommand; bail
    # out early so an existing config file does not trigger the --force error.
    if "--help" in sys.argv:
        return

    # Must match the path computed when the configuration is saved.
    conf_path = (dir / name).with_suffix(".yaml")
    if conf_path.exists() and not force:
        raise FileExistsError(f"To overwrite '{conf_path}' use flag --force/-f.")

    # Share the validated options with the subcommand through the Typer context.
    ctx.obj = ConfOptions(dir, name, force, print)
93
+
94
+
95
+ # TODO: Need to decide how to parse model kwargs
96
+ # - Could be json style string to be loaded as dict e.g. {"depth": 3}
97
+ # - Cons: Annoying to type, easily have syntax errors
98
+ # - Could parse all unknown options as model kwargs
99
+ # - Cons: There could be argument name clashes
100
+
101
+
102
@app.command()
def care(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=handle_2D_3D_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")] = 100,
    num_steps: Annotated[
        int | None,
        typer.Option(help="Number of batches per epoch (limit_train_batches)."),
    ] = None,
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = False,
    loss: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["mae", "mse"]),
            help="Loss function to use.",
        ),
    ] = "mae",
    n_channels_in: Annotated[
        int | None, typer.Option(help="Number of channels in")
    ] = None,
    n_channels_out: Annotated[
        int | None, typer.Option(help="Number of channels out")
    ] = None,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training CARE.

    If "Z" is present in `axes` (3D data), pass three values to `patch_size`;
    otherwise pass -1 as the last value, e.g. `--patch-size 64 64 -1`.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.
    """
    config = create_care_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        num_steps=num_steps,
        # An empty list disables augmentations; None presumably lets the factory
        # pick its defaults. TODO: fix choosing augmentations
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
    )
    # Save the configuration (and optionally print it) using the group options.
    _config_builder_exit(ctx, config)
192
+
193
+
194
@app.command()
def n2n(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=handle_2D_3D_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")] = 100,
    num_steps: Annotated[
        int | None,
        typer.Option(help="Number of batches per epoch (limit_train_batches)."),
    ] = None,
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = False,
    loss: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["mae", "mse"]),
            help="Loss function to use.",
        ),
    ] = "mae",
    n_channels_in: Annotated[
        int | None, typer.Option(help="Number of channels in")
    ] = None,
    n_channels_out: Annotated[
        int | None, typer.Option(help="Number of channels out")
    ] = None,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training Noise2Noise.

    If "Z" is present in `axes` (3D data), pass three values to `patch_size`;
    otherwise pass -1 as the last value, e.g. `--patch-size 64 64 -1`.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set the number of channels, then "C" must be present in
    `axes`. To set the number of output channels, use the `n_channels_out` parameter.

    By default, all channels are trained together. To train all channels independently,
    set `independent_channels` to True.

    By setting `use_augmentations` to False, the only transformation applied will be
    normalization.
    """
    config = create_n2n_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        num_steps=num_steps,
        # An empty list disables augmentations; None presumably lets the factory
        # pick its defaults. TODO: fix choosing augmentations
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
    )
    # Save the configuration (and optionally print it) using the group options.
    _config_builder_exit(ctx, config)
281
+
282
+
283
@app.command()
def n2v(  # numpydoc ignore=PR01
    ctx: typer.Context,
    experiment_name: Annotated[str, typer.Option(help="Name of the experiment.")],
    axes: Annotated[str, typer.Option(help="Axes of the data (e.g. SYX).")],
    patch_size: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Size of the patches along the spatial dimensions (if the data "
                "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=handle_2D_3D_callback,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")],
    num_epochs: Annotated[int, typer.Option(help="Number of epochs.")] = 100,
    num_steps: Annotated[
        int | None,
        typer.Option(help="Number of batches per epoch (limit_train_batches)."),
    ] = None,
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the data."),
    ] = "tiff",
    use_augmentations: Annotated[
        bool, typer.Option(help="Whether to use augmentations.")
    ] = True,
    independent_channels: Annotated[
        bool, typer.Option(help="Whether to train all channels independently.")
    ] = True,
    use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
    n_channels: Annotated[
        int | None, typer.Option(help="Number of channels (in and out)")
    ] = None,
    roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
    masked_pixel_percentage: Annotated[
        float, typer.Option(help="Percentage of pixels masked in each patch.")
    ] = 0.2,
    struct_n2v_axis: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["horizontal", "vertical", "none"]),
            # Only option of the command that previously had no help text.
            help="Axis of the structN2V mask ('none' disables structN2V).",
        ),
    ] = "none",
    struct_n2v_span: Annotated[
        int, typer.Option(help="Span of the structN2V mask.")
    ] = 5,
    logger: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["wandb", "tensorboard", "none"]),
            help="Logger to use.",
        ),
    ] = "none",
    # TODO: How to address model kwargs
) -> None:
    """
    Create a configuration for training Noise2Void.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
    connections, thus removing checkerboard artefacts. StructN2V is used when vertical
    or horizontal correlations are present in the noise; it applies an additional mask
    to the manipulated pixel neighbors.

    If "Z" is present in `axes` (3D data), pass three values to `patch_size`;
    otherwise pass -1 as the last value, e.g. `--patch-size 64 64 -1`.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By setting `use_augmentations` to False, the only transformations applied will be
    normalization and N2V manipulation.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then a structN2V mask
    will be applied to each manipulated pixel.

    UNet model parameters (`model_kwargs`) cannot be passed from the command line yet.
    """
    config = create_n2v_configuration(
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        num_steps=num_steps,
        # An empty list disables augmentations; None presumably lets the factory
        # pick its defaults. TODO: fix choosing augmentations
        augmentations=None if use_augmentations else [],
        independent_channels=independent_channels,
        use_n2v2=use_n2v2,
        n_channels=n_channels,
        roi_size=roi_size,
        masked_pixel_percentage=masked_pixel_percentage,
        struct_n2v_axis=struct_n2v_axis,
        struct_n2v_span=struct_n2v_span,
        logger=logger,
        # TODO: Model kwargs
    )
    # Save the configuration (and optionally print it) using the group options.
    _config_builder_exit(ctx, config)
careamics/cli/main.py ADDED
@@ -0,0 +1,234 @@
1
+ """
2
+ Module for CLI functionality and entrypoint.
3
+
4
Contains the CLI entrypoint (the `run` function) and the first-level subcommands
`train` and `predict`. The `conf` subcommand is added through the `app.add_typer`
function, and its implementation is contained in the conf.py file.
7
+ """
8
+
9
+ from pathlib import Path
10
+ from typing import Annotated
11
+
12
+ import click
13
+ import typer
14
+
15
+ from ..careamist import CAREamist
16
+ from . import conf
17
+ from .utils import handle_2D_3D_callback
18
+
19
# Top-level Typer application. The `conf` sub-application is mounted under the
# name "conf", so its commands are reached as a nested `conf` subcommand.
app = typer.Typer(
    # Description shown for the top-level command in `--help`.
    help=(
        "Run CAREamics algorithms from the command line, including Noise2Void "
        "and its many variants and cousins"
    ),
    # Do not dump local variables when Typer pretty-prints an exception.
    pretty_exceptions_show_locals=False,
)
app.add_typer(conf.app, name="conf")
25
+
26
+
27
@app.command()
def train(  # numpydoc ignore=PR01
    source: Annotated[
        Path,
        typer.Argument(
            help="Path to a configuration file or a trained model.",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ],
    train_source: Annotated[
        Path,
        typer.Option(
            "--train-source",
            "-ts",
            help="Path to the training data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ],
    train_target: Annotated[
        Path | None,
        typer.Option(
            "--train-target",
            "-tt",
            help="Path to train target data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    val_source: Annotated[
        Path | None,
        typer.Option(
            "--val-source",
            "-vs",
            help="Path to validation data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    val_target: Annotated[
        Path | None,
        typer.Option(
            "--val-target",
            "-vt",
            help="Path to validation target data.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ] = None,
    use_in_memory: Annotated[
        bool,
        typer.Option(
            "--use-in-memory/--not-in-memory",
            "-m/-M",
            help="Use in memory dataset if possible.",
        ),
    ] = True,
    val_percentage: Annotated[
        float,
        typer.Option(help="Percentage of files to use for validation."),
    ] = 0.1,
    val_minimum_split: Annotated[
        int,
        # Help text previously ended with a stray comma.
        typer.Option(help="Minimum number of files to use for validation."),
    ] = 1,
    work_dir: Annotated[
        Path | None,
        typer.Option(
            "--work-dir",
            "-wd",
            help="Path to working directory in which to save checkpoints and logs.",
            exists=True,
            file_okay=False,
            dir_okay=True,
        ),
    ] = None,
):
    """Train CAREamics models."""
    # `source` is a configuration file or a trained model (see its help text);
    # CAREamist builds the engine from it, rooted at `work_dir`.
    engine = CAREamist(source=source, work_dir=work_dir)
    # Every data-related CLI option is forwarded unchanged to the training call.
    engine.train(
        train_source=train_source,
        val_source=val_source,
        train_target=train_target,
        val_target=val_target,
        use_in_memory=use_in_memory,
        val_percentage=val_percentage,
        val_minimum_split=val_minimum_split,
    )
121
+
122
+
123
@app.command()
def predict(  # numpydoc ignore=PR01
    model: Annotated[
        Path,
        typer.Argument(
            help="Path to a configuration file or a trained model.",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ],
    source: Annotated[
        Path,
        typer.Argument(
            # Previously described as "training data", which is wrong for predict.
            help="Path to the data to predict on. Can be a directory or single file.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
    tile_size: Annotated[
        click.Tuple | None,
        typer.Option(
            # Typer derives option names with dashes, so the flag is
            # `--tile-size`, not `--tile_size`.
            help=(
                "Size of the tiles to use for prediction, (if the data "
                "is not 3D pass the last value as -1 e.g. --tile-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            # A trailing -1 collapses the tuple to two values (2D tiles).
            callback=handle_2D_3D_callback,
        ),
    ] = None,
    # NOTE(review): the default overlap carries a trailing -1, so it is treated
    # as 2D; for 3D tiles pass --tile-overlap explicitly with three values.
    tile_overlap: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Overlap between tiles, (if the data is not 3D pass the last value as "
                "-1 e.g. --tile-overlap 48 48 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=handle_2D_3D_callback,
        ),
    ] = (48, 48, -1),
    axes: Annotated[
        str | None,
        typer.Option(
            help="Axes of the input data. If unused the data is assumed to have the "
            "same axes as the original training data."
        ),
    ] = None,
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the input data."),
    ] = "tiff",
    tta_transforms: Annotated[
        bool,
        typer.Option(
            "--tta-transforms/--no-tta-transforms",
            "-t/-T",
            help="Whether to apply test-time augmentation.",
        ),
    ] = False,
    write_type: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["tiff"]), help="Type of the output data."
        ),
    ] = "tiff",
    # TODO: could make dataloader_params as json, necessary?
    work_dir: Annotated[
        Path | None,
        typer.Option(
            "--work-dir",
            "-wd",
            help=("Path to working directory."),
            exists=True,
            file_okay=False,
            dir_okay=True,
        ),
    ] = None,
    prediction_dir: Annotated[
        Path,
        typer.Option(
            "--prediction-dir",
            "-pd",
            help=(
                "Directory to save predictions to. If not an absolute path it will be "
                "relative to the set working directory."
            ),
            file_okay=False,
            dir_okay=True,
        ),
    ] = Path("predictions"),
):
    """Create and save predictions from CAREamics models."""
    # `model` is a configuration file or a trained model; the engine is rooted
    # at `work_dir`.
    engine = CAREamist(source=model, work_dir=work_dir)
    # Prediction options are forwarded unchanged; writing to disk is delegated
    # to `predict_to_disk`.
    engine.predict_to_disk(
        source=source,
        batch_size=batch_size,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        axes=axes,
        data_type=data_type,
        tta_transforms=tta_transforms,
        write_type=write_type,
        prediction_dir=prediction_dir,
    )
230
+
231
+
232
def run() -> None:
    """Launch the CAREamics command-line interface by invoking the Typer `app`."""
    app()
careamics/cli/utils.py ADDED
@@ -0,0 +1,27 @@
1
+ """Utility functions for the CAREamics CLI."""
2
+
3
+
4
+ def handle_2D_3D_callback(
5
+ value: tuple[int, int, int] | None,
6
+ ) -> tuple[int, ...] | None:
7
+ """
8
+ Callback for options that require 2D or 3D inputs.
9
+
10
+ In the case of 2D, the 3rd element should be set to -1.
11
+
12
+ Parameters
13
+ ----------
14
+ value : (int, int, int)
15
+ Tile size value.
16
+
17
+ Returns
18
+ -------
19
+ (int, int, int) | (int, int)
20
+ If the last element in `value` is -1 the tuple is reduced to the first two
21
+ values.
22
+ """
23
+ if value is None:
24
+ return value
25
+ if value[2] == -1:
26
+ return value[:2]
27
+ return value