careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,406 @@
1
+ """
2
+ Pixel manipulation methods.
3
+
4
+ Pixel manipulation is used in N2V and similar algorithms to replace the value of
5
+ masked pixels.
6
+ """
7
+
8
+ import numpy as np
9
+
10
+ from .struct_mask_parameters import StructMaskParameters
11
+
12
+
13
def _apply_struct_mask(
    patch: np.ndarray,
    coords: np.ndarray,
    struct_params: StructMaskParameters,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Overwrite the structN2V neighbourhood of every coordinate with random values.

    A one-pixel-wide line of length `struct_params.span` is centred on each point
    of `coords`. Every pixel of that line except the centre itself is replaced by
    a value drawn uniformly between the minimum and the maximum of `patch`.

    The line runs along array axis `-1 - struct_params.axis`, i.e. the last axis
    for `axis=0` and the second to last axis for `axis=1`. Line pixels falling
    outside of the patch along that axis are skipped.

    Note that `patch` is modified in place; the returned array is the same object.

    Parameters
    ----------
    patch : np.ndarray
        Patch to be manipulated, 2D or 3D.
    coords : np.ndarray
        Coordinates of the mask centers, one row per point.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None
        Random number generator, a new default generator is created if None.

    Returns
    -------
    np.ndarray
        Patch with the structN2V mask applied.
    """
    generator = np.random.default_rng() if rng is None else rng

    # array axis along which the structured line extends
    struct_axis = -1 - struct_params.axis

    # 1D kernel of ones with singleton leading axes: (1, span) or (1, 1, span)
    leading_axes = list(range(len(patch.shape) - 1))
    kernel = np.expand_dims(np.ones(struct_params.span), axis=leading_axes)

    # orient the kernel along the struct axis and blank out its center
    kernel = np.moveaxis(kernel, -1, struct_axis)
    kernel_center = np.array(kernel.shape) // 2
    kernel[tuple(kernel_center.T)] = 0

    # offsets of the remaining kernel pixels relative to the center, (ndim, n_off)
    offsets = np.indices(kernel.shape)[:, kernel == 1] - kernel_center[:, None]

    # every offset applied to every coordinate, flattened to (n_off * n_coords, ndim)
    targets = offsets.T[..., None] + coords.T[None]
    targets = targets.transpose([1, 0, 2]).reshape([kernel.ndim, -1]).T

    # keep only the targets lying inside the patch along the struct axis
    position = targets[:, struct_axis]
    last_valid = patch.shape[struct_axis] - 1
    targets = targets[(position >= 0) & (position <= last_valid)]

    # overwrite the neighbours with values from a flat distribution
    patch[tuple(targets.T)] = generator.uniform(
        patch.min(), patch.max(), size=targets.shape[0]
    )

    return patch
79
+
80
+
81
def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
    """
    Round a grid step to an integer, picking the direction at random if needed.

    An integer-valued step is returned unchanged. A fractional step is rounded
    down or up depending on a random draw, which accounts for cases where the
    step size is not an integer.

    One value is always drawn from `rng`, whether or not it ends up being used.

    Parameters
    ----------
    step : float
        Step size of the grid, output of np.linspace.
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    np.ndarray
        The step rounded down or up to a whole number.
    """
    rounded_down = np.floor(step)

    # drawn unconditionally so that the generator advances identically for all steps
    round_up = rng.integers(0, 2)

    if rounded_down == step or round_up == 0:
        return rounded_down
    return np.ceil(step)
104
+
105
+
106
+ def _get_stratified_coords(
107
+ mask_pixel_perc: float,
108
+ shape: tuple[int, ...],
109
+ rng: np.random.Generator | None = None,
110
+ ) -> np.ndarray:
111
+ """
112
+ Generate coordinates of the pixels to mask.
113
+
114
+ Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
115
+ the distance between masked pixels is approximately the same.
116
+
117
+ Parameters
118
+ ----------
119
+ mask_pixel_perc : float
120
+ Actual (quasi) percentage of masked pixels across the whole image. Used in
121
+ calculating the distance between masked pixels across each axis.
122
+ shape : tuple[int, ...]
123
+ Shape of the input patch.
124
+ rng : np.random.Generator or None
125
+ Random number generator.
126
+
127
+ Returns
128
+ -------
129
+ np.ndarray
130
+ Array of coordinates of the masked pixels.
131
+ """
132
+ if len(shape) < 2 or len(shape) > 3:
133
+ raise ValueError(
134
+ "Calculating coordinates is only possible for 2D and 3D patches"
135
+ )
136
+
137
+ if rng is None:
138
+ rng = np.random.default_rng()
139
+
140
+ mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
141
+ np.int32
142
+ )
143
+
144
+ # Define a grid of coordinates for each axis in the input patch and the step size
145
+ pixel_coords = []
146
+ steps = []
147
+ for axis_size in shape:
148
+ # make sure axis size is evenly divisible by box size
149
+ num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
150
+ axis_pixel_coords, step = np.linspace(
151
+ 0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
152
+ )
153
+ # explain
154
+ pixel_coords.append(axis_pixel_coords.T)
155
+ steps.append(step)
156
+
157
+ # Create a meshgrid of coordinates for each axis in the input patch
158
+ coordinate_grid_list = np.meshgrid(*pixel_coords)
159
+ coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T
160
+
161
+ grid_random_increment = rng.integers(
162
+ _odd_jitter_func(float(max(steps)), rng) # type: ignore
163
+ * np.ones_like(coordinate_grid).astype(np.int32)
164
+ - 1,
165
+ size=coordinate_grid.shape,
166
+ endpoint=True,
167
+ )
168
+ coordinate_grid += grid_random_increment
169
+ coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1)
170
+ return coordinate_grid
171
+
172
+
173
def _create_subpatch_center_mask(
    subpatch: np.ndarray, center_coords: np.ndarray
) -> np.ndarray:
    """Build a boolean mask of the subpatch that excludes its center pixel.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated.
    center_coords : np.ndarray
        Coordinates of the center pixel, expressed in the subpatch frame.

    Returns
    -------
    np.ndarray
        Mask of the same shape as the subpatch, False at the center and True
        everywhere else.
    """
    center_index = tuple(center_coords)

    # start from "keep everything" and knock out the center
    keep = np.full(subpatch.shape, 1.0)
    keep[center_index] = 0

    return np.ma.make_mask(keep)  # type: ignore
193
+
194
+
195
def _create_subpatch_struct_mask(
    subpatch: np.ndarray, center_coords: np.ndarray, struct_params: StructMaskParameters
) -> np.ndarray:
    """Create a structN2V mask for the subpatch.

    The mask excludes a one-pixel-wide line of `struct_params.span` pixels centred
    on `center_coords`, the center pixel included. The line runs along array axis
    `subpatch.ndim - 1 - struct_params.axis` (the last axis for `axis=0`, the
    second to last for `axis=1`), which is the axis convention used by
    `_apply_struct_mask`. The line is truncated at the subpatch borders.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated.
    center_coords : np.ndarray
        Coordinates of the center pixel, expressed in the subpatch frame.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    np.ndarray
        Boolean mask of the same shape as the subpatch, False on the structN2V
        line and True everywhere else.
    """
    # array axis along which the structured line extends
    array_axis = subpatch.ndim - 1 - struct_params.axis
    half_span = (struct_params.span - 1) // 2

    center = [int(c) for c in center_coords]

    # extent of the line along the struct axis, truncated to the subpatch
    start = max(0, center[array_axis] - half_span)
    stop = min(center[array_axis] + half_span + 1, subpatch.shape[array_axis])

    # the line passes through the center: all other axes stay at the center index
    line_index = list(center)
    line_index[array_axis] = slice(start, stop)

    mask = np.ones(subpatch.shape, dtype=bool)
    mask[tuple(line_index)] = False

    return mask
235
+
236
+
237
def uniform_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: StructMaskParameters | None = None,
    rng: np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the value of a neighbor.

    Manipulated pixels are selected in a stratified way, i.e. randomly shifted away
    from a regular grid, so that they are spread approximately uniformly across the
    whole patch. Each of them receives the value of a pixel drawn uniformly from
    the subpatch centred on it.

    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    The input patch is left untouched, the manipulation is done on a copy.

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to exclude the zero offset when sampling the replacement along each
        axis, by default True.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None
        Random number generator, a new default generator is created if None. It is
        used for every random draw, including the structN2V values.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The manipulated patch and the corresponding mask. The mask is 1 where the
        value of the pixel was changed by the neighbor replacement and 0 elsewhere;
        a pixel replaced by an identical value is therefore not flagged.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)

    # Offsets spanning the subpatch along one axis, e.g. -5..5 for a size of 11
    roi_span_full = np.arange(
        -np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)
    ).astype(np.int32)

    # Remove the zero offset if needed
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full

    # Randomly select one offset per center and per axis
    random_increment = rng.choice(roi_span, size=subpatch_centers.shape)

    # Clip the coordinates to the patch size
    replacement_coords = np.clip(
        subpatch_centers + random_increment,
        0,
        np.array(patch.shape) - 1,
    )

    # Replacement values are read from the original patch, so that an already
    # manipulated pixel is never used as a source
    replacement_pixels = patch[tuple(replacement_coords.T.tolist())]

    # Replace the original pixels with the replacement pixels
    transformed_patch[tuple(subpatch_centers.T.tolist())] = replacement_pixels
    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        # forward the generator so that seeded runs are reproducible
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng=rng
        )

    return (
        transformed_patch,
        mask,
    )
316
+
317
+
318
def median_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: StructMaskParameters | None = None,
    rng: np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version, manipulated pixels are selected randomly away from a grid with an
    approximate uniform probability to be selected across the whole patch.

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    The input patch is left untouched, the manipulation is done on a copy.

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed over, by default 11.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None, optional
        Random number generator, by default None. It is used for every random draw,
        including the structN2V values.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The manipulated patch and the corresponding mask. The mask is 1 where the
        value of the pixel was changed by the median replacement and 0 elsewhere;
        a pixel replaced by an identical value is therefore not flagged.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)

    # (start, stop) offsets of the subpatch relative to its center
    roi_span = np.array(
        [-np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)]
    ).astype(np.int32)

    # Dimensions n dims, n centers, (min, max)
    subpatch_crops_span_full = subpatch_centers[np.newaxis, ...].T + roi_span

    # Crop the subpatches that extend beyond the patch
    subpatch_crops_span_clipped = np.clip(
        subpatch_crops_span_full,
        a_min=np.zeros_like(patch.shape)[:, np.newaxis, np.newaxis],
        a_max=np.array(patch.shape)[:, np.newaxis, np.newaxis],
    )

    for idx in range(subpatch_crops_span_clipped.shape[1]):
        subpatch_coords = subpatch_crops_span_clipped[:, idx, ...]
        idxs = [
            slice(x[0], x[1]) if x[1] - x[0] > 0 else slice(0, 1)
            for x in subpatch_coords
        ]

        # values are read from the original patch, never from manipulated pixels
        subpatch = patch[tuple(idxs)]

        # position of the center within the (possibly cropped) subpatch
        subpatch_center_adjusted = subpatch_centers[idx] - subpatch_coords[:, 0]

        if struct_params is None:
            subpatch_mask = _create_subpatch_center_mask(
                subpatch, subpatch_center_adjusted
            )
        else:
            subpatch_mask = _create_subpatch_struct_mask(
                subpatch, subpatch_center_adjusted, struct_params
            )
        transformed_patch[tuple(subpatch_centers[idx])] = np.median(
            subpatch[subpatch_mask]
        )

    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        # forward the generator so that seeded runs are reproducible
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng=rng
        )

    return (
        transformed_patch,
        mask,
    )