careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,150 @@
1
+ """N2V manipulation transform."""
2
+
3
+ from typing import Any, Literal
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+ from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
9
+ from careamics.transforms.transform import Transform
10
+
11
+ from .pixel_manipulation import median_manipulate, uniform_manipulate
12
+ from .struct_mask_parameters import StructMaskParameters
13
+
14
+
15
class N2VManipulate(Transform):
    """
    Default augmentation for the N2V model.

    Masks a fraction of the pixels of each channel independently and replaces
    them according to the selected strategy. This transform expects C(Z)YX
    dimensions.

    Parameters
    ----------
    roi_size : int, optional
        Size of the replacement area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels to mask, by default 0.2.
    strategy : Literal[ "uniform", "median" ], optional
        Replacement strategy, uniform or median, by default uniform.
    remove_center : bool, optional
        Whether to remove central pixel from patch, by default True. Only
        forwarded to the uniform strategy.
    struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
        StructN2V mask axis, by default "none".
    struct_mask_span : int, optional
        StructN2V mask span, by default 5.
    seed : Optional[int], optional
        Random seed, by default None.

    Attributes
    ----------
    masked_pixel_percentage : float
        Percentage of pixels to mask.
    roi_size : int
        Size of the replacement area.
    strategy : Literal[ "uniform", "median" ]
        Replacement strategy, uniform or median.
    remove_center : bool
        Whether to remove central pixel from patch.
    struct_mask : Optional[StructMaskParameters]
        StructN2V mask parameters, or None when no structured mask is used.
    rng : Generator
        NumPy random number generator.
    """

    def __init__(
        self,
        roi_size: int = 11,
        masked_pixel_percentage: float = 0.2,
        strategy: Literal[
            "uniform", "median"
        ] = SupportedPixelManipulation.UNIFORM.value,
        remove_center: bool = True,
        struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
        struct_mask_span: int = 5,
        seed: int | None = None,
    ):
        """Constructor.

        Parameters
        ----------
        roi_size : int, optional
            Size of the replacement area, by default 11.
        masked_pixel_percentage : float, optional
            Percentage of pixels to mask, by default 0.2.
        strategy : Literal[ "uniform", "median" ], optional
            Replacement strategy, uniform or median, by default uniform.
        remove_center : bool, optional
            Whether to remove central pixel from patch, by default True.
        struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
            StructN2V mask axis, by default "none".
        struct_mask_span : int, optional
            StructN2V mask span, by default 5.
        seed : Optional[int], optional
            Random seed, by default None.
        """
        self.masked_pixel_percentage = masked_pixel_percentage
        self.roi_size = roi_size
        self.strategy = strategy
        self.remove_center = remove_center  # TODO is this ever used?

        # "horizontal" maps to axis 0, any other non-"none" value to axis 1
        no_struct_mask = struct_mask_axis == SupportedStructAxis.NONE
        self.struct_mask: StructMaskParameters | None = (
            None
            if no_struct_mask
            else StructMaskParameters(
                axis=0 if struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1,
                span=struct_mask_span,
            )
        )

        # numpy random generator shared by every call
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: NDArray, *args: Any, **kwargs: Any
    ) -> tuple[NDArray, NDArray, NDArray]:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        *args : Any
            Additional arguments, unused.
        **kwargs : Any
            Additional keyword arguments, unused.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, np.ndarray]
            Masked patch, original patch, and mask.

        Raises
        ------
        ValueError
            If the masking strategy is neither uniform nor median.
        """
        masked = np.zeros_like(patch)
        mask = np.zeros_like(patch)

        # Select a per-channel manipulation; only the uniform strategy takes
        # the `remove_center` flag.
        if self.strategy == SupportedPixelManipulation.UNIFORM:

            def manipulate_channel(channel_patch: NDArray):
                return uniform_manipulate(
                    patch=channel_patch,
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    remove_center=self.remove_center,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )

        elif self.strategy == SupportedPixelManipulation.MEDIAN:

            def manipulate_channel(channel_patch: NDArray):
                return median_manipulate(
                    patch=channel_patch,
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )

        else:
            raise ValueError(f"Unknown masking strategy ({self.strategy}).")

        # Channels are manipulated independently, in order, so the RNG is
        # consumed channel by channel.
        for channel in range(patch.shape[0]):
            masked[channel, ...], mask[channel, ...] = manipulate_channel(
                patch[channel, ...]
            )

        # TODO: Output does not match other transforms, how to resolve?
        # - Don't include in Compose and apply after if algorithm is N2V?
        # - or just don't return patch? but then mask is in the target position
        # TODO why return patch?
        return masked, patch, mask
@@ -0,0 +1,149 @@
1
+ """N2V manipulation transform for PyTorch."""
2
+
3
+ import platform
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
+ from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
9
+ from careamics.config.transformations import N2VManipulateConfig
10
+
11
+ from .pixel_manipulation_torch import (
12
+ median_manipulate_torch,
13
+ uniform_manipulate_torch,
14
+ )
15
+ from .struct_mask_parameters import StructMaskParameters
16
+
17
+
18
class N2VManipulateTorch:
    """
    Default augmentation for the N2V model.

    Batched PyTorch variant: masks a fraction of the pixels of each channel of
    a batch independently. This transform expects C(Z)YX dimensions per sample,
    i.e. BC(Z)YX batches.

    Parameters
    ----------
    n2v_manipulate_config : N2VManipulateConfig
        N2V manipulation configuration.
    seed : Optional[int], optional
        Random seed, by default None.
    device : str
        The device on which operations take place, e.g. "cuda", "cpu" or "mps".

    Attributes
    ----------
    masked_pixel_percentage : float
        Percentage of pixels to mask.
    roi_size : int
        Size of the replacement area.
    strategy : Literal[ "uniform", "median" ]
        Replacement strategy, uniform or median.
    remove_center : bool
        Whether to remove central pixel from patch.
    struct_mask : Optional[StructMaskParameters]
        StructN2V mask parameters, or None when no structured mask is used.
    rng : Generator
        PyTorch random number generator.
    """

    def __init__(
        self,
        n2v_manipulate_config: N2VManipulateConfig,
        seed: int | None = None,
        device: str | None = None,
    ):
        """Constructor.

        Parameters
        ----------
        n2v_manipulate_config : N2VManipulateConfig
            N2V manipulation configuration.
        seed : Optional[int], optional
            Random seed, by default None.
        device : str
            The device on which operations take place, e.g. "cuda", "cpu" or "mps".
            When None, CUDA is preferred, then MPS on ARM machines, then CPU.
        """
        cfg = n2v_manipulate_config
        self.masked_pixel_percentage = cfg.masked_pixel_percentage
        self.roi_size = cfg.roi_size
        self.strategy = cfg.strategy
        self.remove_center = cfg.remove_center

        # "horizontal" maps to axis 0, any other non-"none" value to axis 1
        no_struct_mask = cfg.struct_mask_axis == SupportedStructAxis.NONE
        self.struct_mask: StructMaskParameters | None = (
            None
            if no_struct_mask
            else StructMaskParameters(
                axis=0 if cfg.struct_mask_axis == SupportedStructAxis.HORIZONTAL else 1,
                span=cfg.struct_mask_span,
            )
        )

        # PyTorch random generator
        if device is None:
            device = self._default_device()
        generator = torch.Generator(device=device)
        # `manual_seed` seeds the generator in place and returns it
        self.rng = generator if seed is None else generator.manual_seed(seed)

    @staticmethod
    def _default_device() -> str:
        """Pick the device used when none is given: cuda, then mps, then cpu."""
        # TODO refactor into careamics.utils.torch_utils.get_device
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available() and platform.processor() in (
            "arm",
            "arm64",
        ):
            return "mps"
        return "cpu"

    def __call__(
        self, batch: torch.Tensor, *args: Any, **kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the transform to the image.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of image patches, 2D or 3D, shape BC(Z)YX.
        *args : Any
            Additional arguments, unused.
        **kwargs : Any
            Additional keyword arguments, unused.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Masked patch, original patch, and mask (uint8).

        Raises
        ------
        ValueError
            If the masking strategy is neither uniform nor median.
        """
        masked = torch.zeros_like(batch)
        mask = torch.zeros_like(batch, dtype=torch.uint8)

        # Select a per-channel manipulation. Note the two helpers take the
        # input under different keyword names (`patch` vs `batch`), and only
        # the uniform one receives `remove_center`.
        if self.strategy == SupportedPixelManipulation.UNIFORM:

            def manipulate_channel(channel_batch: torch.Tensor):
                return uniform_manipulate_torch(
                    patch=channel_batch,
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    remove_center=self.remove_center,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )

        elif self.strategy == SupportedPixelManipulation.MEDIAN:

            def manipulate_channel(channel_batch: torch.Tensor):
                return median_manipulate_torch(
                    batch=channel_batch,
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )

        else:
            raise ValueError(f"Unknown masking strategy ({self.strategy}).")

        # Channels (dim 1) are manipulated independently and in order.
        for channel in range(batch.shape[1]):
            masked[:, channel, ...], mask[:, channel, ...] = manipulate_channel(
                batch[:, channel, ...]
            )

        return masked, batch, mask
@@ -0,0 +1,374 @@
1
+ """Normalization and denormalization transforms for image patches."""
2
+
3
+ import numpy as np
4
+ import torch
5
+ from numpy.typing import NDArray
6
+ from torch import Tensor
7
+
8
+ from careamics.transforms.transform import Transform
9
+
10
+
11
+ def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
12
+ """Reshape stats to match the number of dimensions of the input image.
13
+
14
+ This allows to broadcast the stats (mean or std) to the image dimensions, and
15
+ thus directly perform a vectorial calculation.
16
+
17
+ Parameters
18
+ ----------
19
+ stats : list of float
20
+ List of stats, mean or standard deviation.
21
+ ndim : int
22
+ Number of dimensions of the image, including the C channel.
23
+
24
+ Returns
25
+ -------
26
+ NDArray
27
+ Reshaped stats.
28
+ """
29
+ return np.array(stats)[(..., *[np.newaxis] * (ndim - 1))]
30
+
31
+
32
+ def _reshape_stats_torch(stats: list[float], ndim: int) -> Tensor:
33
+ """Torch equivalent of `_reshape_stats` for broadcasting over image dims.
34
+
35
+ Parameters
36
+ ----------
37
+ stats : list of float
38
+ List of stats, mean or standard deviation.
39
+ ndim : int
40
+ Number of dimensions of the tensor, including the C channel.
41
+
42
+ Returns
43
+ -------
44
+ Tensor
45
+ Reshaped stats tensor.
46
+ """
47
+ t = torch.tensor(stats)
48
+ # Add singleton dimensions to match input tensor ndim for broadcasting
49
+ return t[(..., *[None] * (ndim - 1))]
50
+
51
+
52
+ class Normalize(Transform):
53
+ """
54
+ Normalize an image or image patch.
55
+
56
+ Normalization is a zero mean and unit variance. This transform expects C(Z)YX
57
+ dimensions.
58
+
59
+ Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
60
+ division by zero and that it returns a float32 image.
61
+
62
+ Parameters
63
+ ----------
64
+ image_means : list of float
65
+ Mean value per channel.
66
+ image_stds : list of float
67
+ Standard deviation value per channel.
68
+ target_means : list of float, optional
69
+ Target mean value per channel, by default None.
70
+ target_stds : list of float, optional
71
+ Target standard deviation value per channel, by default None.
72
+
73
+ Attributes
74
+ ----------
75
+ image_means : list of float
76
+ Mean value per channel.
77
+ image_stds : list of float
78
+ Standard deviation value per channel.
79
+ target_means :list of float, optional
80
+ Target mean value per channel, by default None.
81
+ target_stds : list of float, optional
82
+ Target standard deviation value per channel, by default None.
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ image_means: list[float],
88
+ image_stds: list[float],
89
+ target_means: list[float] | None = None,
90
+ target_stds: list[float] | None = None,
91
+ ):
92
+ """Constructor.
93
+
94
+ Parameters
95
+ ----------
96
+ image_means : list of float
97
+ Mean value per channel.
98
+ image_stds : list of float
99
+ Standard deviation value per channel.
100
+ target_means : list of float, optional
101
+ Target mean value per channel, by default None.
102
+ target_stds : list of float, optional
103
+ Target standard deviation value per channel, by default None.
104
+ """
105
+ self.image_means = image_means
106
+ self.image_stds = image_stds
107
+ self.target_means = target_means
108
+ self.target_stds = target_stds
109
+
110
+ self.eps = 1e-6
111
+
112
def __call__(
    self,
    patch: NDArray,
    target: NDArray | None = None,
    **additional_arrays: NDArray,
) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
    """Apply the transform to the source patch and the target (optional).

    Parameters
    ----------
    patch : NDArray
        Patch, 2D or 3D, shape C(Z)YX.
    target : NDArray, optional
        Target for the patch, by default None.
    **additional_arrays : NDArray
        Additional arrays. Transforming additional arrays is not supported by
        `Normalize`, so this must be empty.

    Returns
    -------
    tuple of (NDArray, NDArray or None, dict of str to NDArray)
        Normalized patch (float32), normalized target (float32) or `None` if no
        target was given, and the (empty) `additional_arrays` dictionary.

    Raises
    ------
    ValueError
        If the number of image means or standard deviations does not match the
        number of channels of `patch`, or if a target is given while the target
        statistics are missing or do not match its number of channels.
    NotImplementedError
        If any additional arrays are passed.
    """
    n_channels = patch.shape[0]
    if len(self.image_means) != n_channels:
        raise ValueError(
            f"Number of means (got a list of size {len(self.image_means)}) and "
            f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
        )
    # Validate the stds as well. Otherwise a mismatched list either fails later
    # with an opaque broadcasting error or, depending on how `_reshape_stats`
    # shapes it (e.g. a single value), is silently applied to every channel.
    if len(self.image_stds) != n_channels:
        raise ValueError(
            f"Number of standard deviations (got a list of size "
            f"{len(self.image_stds)}) and number of channels (got shape "
            f"{patch.shape} for C(Z)YX) do not match."
        )
    if len(additional_arrays) != 0:
        raise NotImplementedError(
            "Transforming additional arrays is currently not supported for "
            "`Normalize`."
        )

    # reshape mean and std and apply the normalization to the patch
    means = _reshape_stats(self.image_means, patch.ndim)
    stds = _reshape_stats(self.image_stds, patch.ndim)
    norm_patch = self._apply(patch, means, stds)

    # same for the target patch
    if target is None:
        norm_target = None
    else:
        # Explicit None/length checks (rather than truthiness) also work when
        # the statistics are stored as numpy arrays.
        if (
            self.target_means is None
            or self.target_stds is None
            or len(self.target_means) == 0
            or len(self.target_stds) == 0
        ):
            raise ValueError(
                "Target means and standard deviations must be provided "
                "if target is not None."
            )
        n_target_channels = target.shape[0]
        if (
            len(self.target_means) != n_target_channels
            or len(self.target_stds) != n_target_channels
        ):
            raise ValueError(
                "Target means and standard deviations must have the same length "
                "as the target."
            )
        target_means = _reshape_stats(self.target_means, target.ndim)
        target_stds = _reshape_stats(self.target_stds, target.ndim)
        norm_target = self._apply(target, target_means, target_stds)

    return norm_patch, norm_target, additional_arrays
175
+
176
def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
    """
    Standardize a patch with the given statistics.

    Computes ``(patch - mean) / (std + eps)`` and casts the result to float32.
    The epsilon guards against division by zero for constant channels.

    Parameters
    ----------
    patch : NDArray
        Image patch, 2D or 3D, shape C(Z)YX.
    mean : NDArray
        Mean values, broadcastable against `patch`.
    std : NDArray
        Standard deviations, broadcastable against `patch`.

    Returns
    -------
    NDArray
        Normalized image patch, dtype float32.
    """
    centred = patch - mean
    safe_std = std + self.eps
    scaled = centred / safe_std
    return scaled.astype(np.float32)
195
+
196
+
197
class Denormalize:
    """
    Undo the normalization of a batch of images.

    The input is expected to be zero-mean, unit-variance data. Each channel is
    mapped back with ``x * (std + eps) + mean``, where ``eps = 1e-6``. The
    normalization step adds this epsilon to the standard deviation to avoid
    division by zero, so it is also used here. Input batches are laid out as
    BC(Z)YX.

    Parameters
    ----------
    image_means : list or tuple of float
        Mean value per channel.
    image_stds : list or tuple of float
        Standard deviation value per channel.
    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
    ):
        """Store the per-channel statistics.

        Parameters
        ----------
        image_means : list of float
            Mean value per channel.
        image_stds : list of float
            Standard deviation value per channel.
        """
        self.image_means = image_means
        self.image_stds = image_stds
        # Same epsilon as the one added to the std during normalization, so
        # that denormalization inverts that step.
        self.eps = 1e-6

    def __call__(self, patch: NDArray) -> NDArray:
        """Reverse the normalization operation for a batch of patches.

        Parameters
        ----------
        patch : NDArray
            Batch of patches, 2D or 3D, shape BC(Z)YX.

        Returns
        -------
        NDArray
            Denormalized array, dtype float32.
        """
        # TODO for pn2v channel handling needs to be changed; until then the
        # number of channels in `patch` is not validated against the stats.
        ndim = patch.ndim
        # `_reshape_stats` is assumed to hold the channels on axis 0. Swapping
        # axes 0 and 1 moves them to axis 1, the C axis of BC(Z)YX.
        mean_bc = np.swapaxes(_reshape_stats(self.image_means, ndim), 0, 1)
        std_bc = np.swapaxes(_reshape_stats(self.image_stds, ndim), 0, 1)
        restored = self._apply(patch, mean_bc, std_bc)
        return restored.astype(np.float32)

    def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
        """
        Rescale and shift an array back to its original intensity range.

        Parameters
        ----------
        array : NDArray
            Normalized batch, shape BC(Z)YX.
        mean : NDArray
            Mean values, broadcastable against `array`.
        std : NDArray
            Standard deviations, broadcastable against `array`.

        Returns
        -------
        NDArray
            Denormalized image array.
        """
        scale = std + self.eps
        return array * scale + mean
286
+
287
+
288
class TrainDenormalize:
    """
    Denormalize an image tensor for training-time tensors.

    This class mirrors `Denormalize` but operates on torch tensors. It expects
    the input tensor to have shape BC(Z)YX with the channel dimension at index 1.

    The computation runs in at least float32 precision. As a result,
    low-precision inputs (e.g. float16/bfloat16 under mixed precision) or
    integer inputs do not round or truncate the statistics. The returned
    tensor is float32.

    Parameters
    ----------
    image_means : list or tuple of float
        Mean value per channel.
    image_stds : list or tuple of float
        Standard deviation value per channel.
    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
    ) -> None:
        """Initialize TrainDenormalize transform.

        Parameters
        ----------
        image_means : list of float
            Mean values per channel.
        image_stds : list of float
            Standard deviation values per channel.
        """
        self.image_means = image_means
        self.image_stds = image_stds
        # Same epsilon as the one added to the std during normalization.
        self.eps = 1e-6

    def __call__(self, patch: Tensor) -> Tensor:
        """Reverse the normalization operation for a batch of patches.

        Parameters
        ----------
        patch : Tensor
            Patch, 2D or 3D, shape BC(Z)YX.

        Returns
        -------
        Tensor
            Denormalized tensor with dtype float32, on the device of `patch`.
        """
        # TODO for pn2v channel handling needs to be changed; until then the
        # number of channels in `patch` is not validated against the stats.

        # Compute in at least float32. Casting the stats to a half-precision or
        # integer `patch.dtype` would round or truncate them (e.g. a mean of
        # 0.5 becomes 0 for an integer tensor). float64 inputs stay float64,
        # so float32/float64 inputs give the same results as before.
        compute_dtype = torch.promote_types(patch.dtype, torch.float32)
        patch = patch.to(dtype=compute_dtype)

        means = _reshape_stats_torch(self.image_means, patch.ndim).to(
            device=patch.device, dtype=compute_dtype
        )
        stds = _reshape_stats_torch(self.image_stds, patch.ndim).to(
            device=patch.device, dtype=compute_dtype
        )

        denorm_tensor = self._apply(
            patch,
            torch.swapaxes(means, 0, 1),  # swap axes as C channel is axis 1
            torch.swapaxes(stds, 0, 1),
        )

        return denorm_tensor.float()

    def _apply(self, array: Tensor, mean: Tensor, std: Tensor) -> Tensor:
        """Apply the denormalization to the tensor.

        Parameters
        ----------
        array : Tensor
            Input tensor, shape BC(Z)YX.
        mean : Tensor
            Mean values, broadcastable against `array`.
        std : Tensor
            Standard deviation values, broadcastable against `array`.

        Returns
        -------
        Tensor
            Denormalized tensor.
        """
        return array * (std + self.eps) + mean