careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,61 @@
1
+ from collections.abc import Sequence
2
+
3
+ from .patching_strategy_protocol import PatchSpecs
4
+
5
+
6
class WholeSamplePatchingStrategy:
    """
    Patching strategy that returns each whole sample as a single "patch".

    For every image stack, one patch specification is created per sample (the
    first axis of each data shape). Each patch starts at the origin and spans
    the full extent of the axes after the first two.

    Notes
    -----
    Patches from image stacks with different spatial shapes have different
    sizes. When the image stacks do not all share the same spatial
    dimensions, this strategy should therefore only be used with a batch size
    of 1.

    Parameters
    ----------
    data_shapes : sequence of sequence of int
        The shape of each image stack. Axis 0 is iterated over as the sample
        axis, axis 1 is skipped (presumably channels -- TODO confirm), and the
        remaining axes are used as the spatial extent of each patch.

    Attributes
    ----------
    data_shapes : sequence of sequence of int
        The shapes passed at construction.
    patch_specs : list of PatchSpecs
        One patch specification per sample, ordered by image stack, then by
        sample.
    """

    # TODO: emit a runtime warning about the batch-size restriction above when
    # image stacks with different spatial dimensions are combined.

    def __init__(self, data_shapes: Sequence[Sequence[int]]):
        """
        Create one whole-sample patch specification per sample.

        Parameters
        ----------
        data_shapes : sequence of sequence of int
            The shape of each image stack: axis 0 is the sample axis and the
            axes from index 2 onwards form the extent of each patch.
        """
        self.data_shapes = data_shapes

        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """
        Total number of patches, i.e. the number of samples over all stacks.

        Returns
        -------
        int
            The number of patch specifications.
        """
        return len(self.patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """
        Return the patch specification at `index`.

        Parameters
        ----------
        index : int
            Index of the patch, in the range ``[0, n_patches)``.

        Returns
        -------
        PatchSpecs
            The patch specification stored at `index`.
        """
        return self.patch_specs[index]

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the indices of the patches that come from a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices that, when used to index the
            `CAREamicsDataset`, will return a patch that comes from the
            `image_stack` corresponding to the given `data_idx`.
        """
        return [
            i
            for i, patch_spec in enumerate(self.patch_specs)
            if patch_spec["data_idx"] == data_idx
        ]

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """
        Build the patch specifications for every sample of every image stack.

        Returns
        -------
        list of PatchSpecs
            One specification per sample, each starting at the origin and
            covering ``data_shape[2:]``.
        """
        patch_specs: list[PatchSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            spatial_shape = data_shape[2:]
            # Every whole-sample patch starts at the origin; the tuple is
            # immutable, so it can be shared between the specs.
            origin = (0,) * len(spatial_shape)
            for sample_idx in range(data_shape[0]):
                patch_specs.append(
                    {
                        "data_idx": data_idx,
                        "sample_idx": sample_idx,
                        "coords": origin,
                        "patch_size": spatial_shape,
                    }
                )
        return patch_specs
@@ -0,0 +1,15 @@
1
+ """Functions relating reading and writing image files."""
2
+
3
# Public API of `careamics.file_io`: the read/write sub-packages and their
# main entry points (kept in sorted order).
__all__ = [
    "ReadFunc", "SupportedWriteType", "WriteFunc",
    "get_read_func", "get_write_func",
    "read", "write",
]
12
+
13
+ from . import read, write
14
+ from .read import ReadFunc, get_read_func
15
+ from .write import SupportedWriteType, WriteFunc, get_write_func
@@ -0,0 +1,11 @@
1
+ """Functions relating to reading image files of different formats."""
2
+
3
# Every name listed here must actually be imported in this module; otherwise
# `from careamics.file_io.read import *` fails with an AttributeError.
__all__ = [
    "ReadFunc",
    "get_read_func",
    "read_tiff",
]
9
+
10
+ from .get_func import ReadFunc, get_read_func
11
+ from .tiff import read_tiff
@@ -0,0 +1,57 @@
1
+ """Module to get read functions."""
2
+
3
+ from collections.abc import Callable
4
+ from pathlib import Path
5
+ from typing import Protocol, Union
6
+
7
+ from numpy.typing import NDArray
8
+
9
+ from careamics.config.support import SupportedData
10
+
11
+ from .tiff import read_tiff
12
+
13
+
14
+ # This is very strict, function signature has to match including arg names
15
+ # See WriteFunc notes
16
class ReadFunc(Protocol):
    """
    Structural type used to annotate image-reading functions.

    A callable conforms when its signature matches `__call__` below, including
    the argument name `file_path`.
    """

    def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
        """
        Signature that conforming callables must match (excluding `self`).

        Parameters
        ----------
        file_path : pathlib.Path
            Location of the file to read.
        *args
            Any extra positional arguments.
        **kwargs
            Any extra keyword arguments.

        Returns
        -------
        numpy.typing.NDArray
            The image data read from `file_path`.
        """
        ...
32
+
33
+
34
# Registry of the read function available for each supported data type.
READ_FUNCS: dict[SupportedData, ReadFunc] = {SupportedData.TIFF: read_tiff}
37
+
38
+
39
def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
    """
    Look up the read function registered for a data type.

    Parameters
    ----------
    data_type : str or SupportedData
        Data type, either as a `SupportedData` member or as its string value.

    Returns
    -------
    callable
        Read function.

    Raises
    ------
    NotImplementedError
        If no read function is registered for `data_type`.
    """
    if data_type not in READ_FUNCS:
        raise NotImplementedError(f"Data type '{data_type}' is not supported.")

    # Normalise to the enum member before indexing (mypy complains about the
    # dict key type otherwise).
    return READ_FUNCS[SupportedData(data_type)]
@@ -0,0 +1,58 @@
1
+ """Functions to read tiff images."""
2
+
3
+ import logging
4
+ from fnmatch import fnmatch
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import tifffile
9
+
10
+ from careamics.config.support import SupportedData
11
+ from careamics.utils.logging import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
    """
    Read a tiff file and return a numpy array.

    Parameters
    ----------
    file_path : pathlib.Path
        Path to a file. A `str` path is also accepted and converted to `Path`.
    *args : list
        Additional positional arguments, not used by this function.
    **kwargs : dict
        Additional keyword arguments, not used by this function.

    Returns
    -------
    np.ndarray
        Resulting array.

    Raises
    ------
    ValueError
        If the file extension does not match the tiff extension pattern.
    ValueError
        If `tifffile` fails to read the file.
    OSError
        If the file failed to open.
    """
    file_path = Path(file_path)

    if not fnmatch(
        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
    ):
        raise ValueError(f"File {file_path} is not a valid tiff.")

    try:
        return tifffile.imread(file_path)
    except (ValueError, OSError) as e:
        # Log through the module logger, then propagate: skipping the file
        # (if desired) is the caller's decision.
        logger.exception("Exception in file %s: %s", file_path, e)
        raise
@@ -0,0 +1,15 @@
1
+ """Functions relating to writing image files of different formats."""
2
+
3
# Public API of `careamics.file_io.write` (kept in sorted order).
__all__ = [
    "SupportedWriteType", "WriteFunc",
    "get_write_func", "write_tiff",
]
9
+
10
+ from .get_func import (
11
+ SupportedWriteType,
12
+ WriteFunc,
13
+ get_write_func,
14
+ )
15
+ from .tiff import write_tiff
@@ -0,0 +1,63 @@
1
+ """Module to get write functions."""
2
+
3
+ from pathlib import Path
4
+ from typing import Literal, Protocol
5
+
6
+ from numpy.typing import NDArray
7
+
8
+ from careamics.config.support import SupportedData
9
+
10
+ from .tiff import write_tiff
11
+
12
# Data types that may be requested from `get_write_func`.
SupportedWriteType = Literal[
    "tiff",
    "zarr",
    "custom",
]
13
+
14
+
15
+ # This is very strict, arguments have to be called file_path & img
16
+ # Alternative? - doesn't capture *args & **kwargs
17
+ # WriteFunc = Callable[[Path, NDArray], None]
18
class WriteFunc(Protocol):
    """
    Structural type used to annotate image-writing functions.

    A callable conforms when its signature matches `__call__` below, including
    the argument names `file_path` and `img`.
    """

    def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
        """
        Signature that conforming callables must match (excluding `self`).

        Parameters
        ----------
        file_path : pathlib.Path
            Destination file.
        img : numpy.ndarray
            Image data to write.
        *args
            Any extra positional arguments.
        **kwargs
            Any extra keyword arguments.
        """
        ...
36
+
37
+
38
# Registry of the write function available for each supported data type.
WRITE_FUNCS: dict[SupportedData, WriteFunc] = {SupportedData.TIFF: write_tiff}
41
+
42
+
43
def get_write_func(data_type: SupportedWriteType) -> WriteFunc:
    """
    Get the write function for the data type.

    Parameters
    ----------
    data_type : {"tiff", "zarr", "custom"}
        Data type. Only types present in this module's `WRITE_FUNCS` registry
        (currently "tiff") have a write function.

    Returns
    -------
    WriteFunc
        Write function.

    Raises
    ------
    ValueError
        Raised by the `SupportedData` conversion if `data_type` is not a
        supported data type.
    NotImplementedError
        If `data_type` is supported but has no registered write function.
    """
    # Error raised here if not supported; a new variable keeps mypy happy
    # about the dict key type.
    data_type_ = SupportedData(data_type)
    if data_type_ not in WRITE_FUNCS:
        raise NotImplementedError(f"No write function for data type '{data_type}'.")

    return WRITE_FUNCS[data_type_]
@@ -0,0 +1,40 @@
1
+ """Write tiff function."""
2
+
3
+ from fnmatch import fnmatch
4
+ from pathlib import Path
5
+
6
+ import tifffile
7
+ from numpy.typing import NDArray
8
+
9
+ from careamics.config.support import SupportedData
10
+
11
+
12
def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
    """
    Write tiff files.

    Parameters
    ----------
    file_path : pathlib.Path
        Path to file. A `str` path is also accepted and converted to `Path`.
    img : numpy.ndarray
        Image data to save.
    *args
        Positional arguments passed to `tifffile.imwrite`.
    **kwargs
        Keyword arguments passed to `tifffile.imwrite`.

    Raises
    ------
    ValueError
        When the file extension of `file_path` does not match the Unix shell-style
        pattern '*.tif*'.
    """
    # TODO: add link to tifffile docs for args kwargs?
    # Coerce to Path so that `.suffix` is available for plain string paths.
    file_path = Path(file_path)
    if not fnmatch(
        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
    ):
        raise ValueError(
            f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
        )
    tifffile.imwrite(file_path, img, *args, **kwargs)
@@ -0,0 +1,32 @@
1
+ """CAREamics PyTorch Lightning modules."""
2
+
3
# Public API of the CAREamics PyTorch Lightning package (kept in sorted order).
__all__ = [
    "DataStatsCallback", "FCNModule", "HyperParametersCallback",
    "MicroSplitDataModule", "PredictDataModule", "ProgressBarCallback",
    "TrainDataModule", "VAEModule",
    "create_careamics_module",
    "create_microsplit_predict_datamodule",
    "create_microsplit_train_datamodule",
    "create_predict_datamodule", "create_train_datamodule",
    "create_unet_based_module", "create_vae_based_module",
]
20
+
21
+ from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
22
+ from .lightning_module import FCNModule, VAEModule, create_careamics_module
23
+ from .microsplit_data_module import (
24
+ MicroSplitDataModule,
25
+ create_microsplit_predict_datamodule,
26
+ create_microsplit_train_datamodule,
27
+ )
28
+ from .predict_data_module import PredictDataModule, create_predict_datamodule
29
+ from .train_data_module import (
30
+ TrainDataModule,
31
+ create_train_datamodule,
32
+ )
@@ -0,0 +1,13 @@
1
+ """Callbacks module."""
2
+
3
# Callbacks exported by this package (kept in sorted order).
__all__ = [
    "DataStatsCallback", "HyperParametersCallback",
    "PredictionWriterCallback", "ProgressBarCallback",
]
9
+
10
+ from .data_stats_callback import DataStatsCallback
11
+ from .hyperparameters_callback import HyperParametersCallback
12
+ from .prediction_writer_callback import PredictionWriterCallback
13
+ from .progress_bar_callback import ProgressBarCallback
@@ -0,0 +1,33 @@
1
+ """Data statistics callback."""
2
+
3
+ import pytorch_lightning as L
4
+ from pytorch_lightning.callbacks import Callback
5
+
6
+
7
class DataStatsCallback(Callback):
    """Pass the datamodule's data statistics to the model before fitting.

    During the ``"fit"`` setup stage, the target mean and standard deviation
    reported by the datamodule are handed to the model's noise-model
    likelihood through its ``set_data_stats`` method.
    """

    def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
        """Copy the target statistics into the likelihood when fitting.

        Parameters
        ----------
        trainer : Lightning.Trainer
            PyTorch Lightning trainer; its ``datamodule`` supplies the stats.
        module : Lightning.LightningModule
            Lightning module exposing a ``noise_model_likelihood`` attribute.
        stage : str
            Current stage (fit, validate, test, or predict). Only ``"fit"``
            triggers the update.
        """
        if stage != "fit":
            return

        # `get_data_stats` returns ((means, stds), <ignored>); only the
        # "target" entries are forwarded to the likelihood.
        stats, _ = trainer.datamodule.get_data_stats()
        means, stds = stats
        module.noise_model_likelihood.set_data_stats(
            data_mean=means["target"], data_std=stds["target"]
        )
@@ -0,0 +1,49 @@
1
+ """Callback saving CAREamics configuration as hyperparameters in the model."""
2
+
3
+ from pytorch_lightning import LightningModule, Trainer
4
+ from pytorch_lightning.callbacks import Callback
5
+
6
+ from careamics.config import Configuration
7
+
8
+
9
class HyperParametersCallback(Callback):
    """
    Record the CAREamics configuration in the model's hyperparameters.

    When training starts, the configuration is dumped to a dictionary and merged
    into ``pl_module.hparams``. This stores it in the checkpoints, so that it
    can later be loaded by a CAREamist instance.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration to record as hyperparameters.

    Attributes
    ----------
    config : Configuration
        CAREamics configuration to record as hyperparameters.
    """

    def __init__(self, config: Configuration) -> None:
        """
        Keep a reference to the configuration until training starts.

        Parameters
        ----------
        config : Configuration
            CAREamics configuration to record as hyperparameters.
        """
        self.config = config

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Merge the dumped configuration into the module's hyperparameters.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer, unused.
        pl_module : LightningModule
            PyTorch Lightning module whose ``hparams`` are updated.
        """
        config_dict = self.config.model_dump()
        pl_module.hparams.update(config_dict)
@@ -0,0 +1,20 @@
1
+ """A package for the `PredictionWriterCallback` class and utilities."""
2
+
3
# Public API of the prediction-writer package (kept in sorted order).
__all__ = [
    "CacheTiles", "PredictionWriterCallback",
    "WriteImage", "WriteStrategy", "WriteTilesZarr",
    "create_write_strategy", "select_write_extension", "select_write_func",
]
13
+
14
+ from .prediction_writer_callback import PredictionWriterCallback
15
+ from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr
16
+ from .write_strategy_factory import (
17
+ create_write_strategy,
18
+ select_write_extension,
19
+ select_write_func,
20
+ )
@@ -0,0 +1,56 @@
1
+ """Module containing file path utilities for `WriteStrategy` to use."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
7
+
8
+
9
+ # TODO: move to datasets package ?
10
def get_sample_file_path(
    dataset: Union[IterableTiledPredDataset, IterablePredDataset], sample_id: int
) -> Path:
    """
    Look up the source file of one sample in a prediction dataset.

    Parameters
    ----------
    dataset : IterableTiledPredDataset or IterablePredDataset
        Dataset whose `data_files` list is indexed.
    sample_id : int
        Sample ID, i.e. the position of the file in `dataset.data_files`.

    Returns
    -------
    Path
        The entry of `dataset.data_files` at position `sample_id`.
    """
    source_files = dataset.data_files
    return source_files[sample_id]
29
+
30
+
31
def create_write_file_path(
    dirpath: Path, file_path: Path, write_extension: str
) -> Path:
    """
    Create the file name for the output file.

    Takes the original file path, changes the directory to `dirpath` and
    replaces the final extension with `write_extension`. Only the last suffix
    is replaced, so other dots in the file name are preserved (e.g.
    ``img.v2.tif`` becomes ``img.v2.tiff``, not ``img.tiff``).

    Parameters
    ----------
    dirpath : pathlib.Path
        The output directory to write file to.
    file_path : pathlib.Path
        The original file path.
    write_extension : str
        The extension that output files should have, including the leading dot.

    Returns
    -------
    Path
        The output file path.

    Raises
    ------
    ValueError
        If `write_extension` is not a valid suffix for
        `pathlib.PurePath.with_suffix`.
    """
    # Calling `with_suffix` on the full name replaces only the last suffix.
    # Rebuilding a Path from `stem` would drop every earlier dotted part of the
    # name, which can make distinct inputs map to the same output file.
    file_name = file_path.with_suffix(write_extension).name
    return dirpath / file_name
+ return file_path