careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,388 @@
1
+ """N2V manipulation functions for PyTorch."""
2
+
3
+ import torch
4
+
5
+ from .struct_mask_parameters import StructMaskParameters
6
+
7
+
8
def _apply_struct_mask_torch(
    patch: torch.Tensor,
    coords: torch.Tensor,
    struct_params: StructMaskParameters,
    rng: torch.Generator | None = None,
) -> torch.Tensor:
    """Apply structN2V masks to patch.

    Each point in `coords` corresponds to the center of a mask. Masks are
    parameterized by `struct_params`. Pixels in the mask (with respect to `coords`)
    are replaced by values drawn uniformly between the minimum and maximum of the
    corresponding sample. The center pixel itself is left untouched.

    The structN2V mask is a 1D line. Axis 0 runs along the last dimension (x);
    axis 1 runs along the second-to-last dimension (y). All other dimensions keep
    the center coordinate, so in 3D the line stays within the center's z-plane.

    Parameters
    ----------
    patch : torch.Tensor
        Patch to be manipulated, (batch, y, x) or (batch, z, y, x). It is
        modified in place.
    coords : torch.Tensor
        Integer coordinates of the ROI (subpatch) centers, shape
        (n_points, patch.ndim). The first column is the batch index.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator, optional
        Random number generator. If None, a new generator is created on
        `patch.device`.

    Returns
    -------
    torch.Tensor
        `patch` (the same object), with the structN2V mask applied.
    """
    if rng is None:
        rng = torch.Generator(device=patch.device)

    # axis 0 -> last dim (x), axis 1 -> second-to-last dim (y)
    moving_axis = -1 - struct_params.axis

    # a line of ones along the moving axis, singleton along every other dim
    mask_shape = [1] * patch.ndim
    mask_shape[moving_axis] = struct_params.span
    mask = torch.ones(mask_shape, device=patch.device)

    center = torch.tensor(mask.shape, device=patch.device) // 2

    # the center pixel is left to the N2V manipulation, exclude it here
    mask[tuple(center)] = 0

    # displacements from the center, shape (ndim, n_displacements)
    displacements = torch.stack(torch.where(mask == 1)) - center.unsqueeze(1)

    # combine every coordinate with every displacement -> (n_disp * n_pts, ndim)
    mix = displacements.T.unsqueeze(-1) + coords.T.unsqueeze(0)
    mix = mix.permute([1, 0, 2]).reshape([mask.ndim, -1]).T

    # only the moving axis is displaced, so only it can leave the patch
    valid_indices = (mix[:, moving_axis] >= 0) & (
        mix[:, moving_axis] < patch.shape[moving_axis]
    )
    mix = mix[valid_indices]

    # per-sample intensity range over every non-batch dimension (works in 2D and 3D)
    flat = patch.reshape(patch.shape[0], -1)
    mins = flat.amin(dim=1)
    maxs = flat.amax(dim=1)

    for sample in range(patch.shape[0]):
        sample_coords = mix[mix[:, 0] == sample]
        random_values = torch.empty(
            len(sample_coords), device=patch.device
        ).uniform_(mins[sample].item(), maxs[sample].item(), generator=rng)
        patch[tuple(sample_coords.T)] = random_values

    return patch
79
+
80
+
81
+ def _get_stratified_coords_torch(
82
+ mask_pixel_perc: float,
83
+ shape: tuple[int, ...],
84
+ rng: torch.Generator,
85
+ ) -> torch.Tensor:
86
+ """
87
+ Generate coordinates of the pixels to mask.
88
+
89
+ Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
90
+ the distance between masked pixels is approximately the same. This is achieved by
91
+ defining a grid and sampling a pixel in each grid square. The grid is defined such
92
+ that the resulting density of masked pixels is the desired masked pixel percentage.
93
+
94
+ Parameters
95
+ ----------
96
+ mask_pixel_perc : float
97
+ Expected value for percentage of masked pixels across the whole image.
98
+ shape : tuple[int, ...]
99
+ Shape of the input patch.
100
+ rng : torch.Generator or None
101
+ Random number generator.
102
+
103
+ Returns
104
+ -------
105
+ np.ndarray
106
+ Array of coordinates of the masked pixels.
107
+ """
108
+ # Implementation logic:
109
+ # find a box size s.t sampling 1 pixel within the box will result in the desired
110
+ # pixel percentage. Make a grid of these boxes that cover the patch (the area of
111
+ # the grid will be greater than or equal to the area of the patch) and sample 1
112
+ # pixel in each box. The density of masked pixels is an intensive property therefore
113
+ # any subset of this area will have the desired expected masked pixel percentage.
114
+ # We can get our desired patch with our desired expected masked pixel percentage by
115
+ # simply filtering out masked pixels that lie outside of our patch bounds.
116
+
117
+ batch_size = shape[0]
118
+ spatial_shape = shape[1:]
119
+
120
+ n_dims = len(spatial_shape)
121
+ expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
122
+
123
+ # keep the grid size in floats for a more accurate expected masked pixel percentage
124
+ grid_size = expected_area_per_pixel ** (1 / n_dims)
125
+ grid_dims = torch.ceil(torch.tensor(spatial_shape) / grid_size).int()
126
+
127
+ # coords on a fixed grid (top left corner)
128
+ coords = torch.stack(
129
+ torch.meshgrid(
130
+ torch.arange(batch_size, dtype=torch.float),
131
+ *[torch.arange(0, grid_dims[i].item()) * grid_size for i in range(n_dims)],
132
+ indexing="ij",
133
+ ),
134
+ -1,
135
+ ).reshape(-1, n_dims + 1)
136
+
137
+ # add random offset to get a random coord in each grid box
138
+ # also keep the offset in floats
139
+ offset = (
140
+ torch.rand((len(coords), n_dims), device=rng.device, generator=rng) * grid_size
141
+ )
142
+ coords = coords.to(rng.device)
143
+ coords[:, 1:] += offset
144
+ coords = torch.floor(coords).int()
145
+
146
+ # filter pixels out of bounds
147
+ out_of_bounds = (
148
+ coords[:, 1:]
149
+ >= torch.tensor(spatial_shape, device=rng.device).reshape(1, n_dims)
150
+ ).any(1)
151
+ coords = coords[~out_of_bounds]
152
+ return coords
153
+
154
+
155
def uniform_manipulate_torch(
    patch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: StructMaskParameters | None = None,
    rng: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with a neighbor values.

    Pixels to manipulate are selected with a stratified grid over the spatial
    dimensions of every sample in the batch. This gives an approximately uniform
    probability of selection across the whole patch. Each selected pixel is
    replaced by the value of a pixel from the subpatch centered on it.

    The offset along each spatial axis is drawn independently and uniformly from
    `-(subpatch_size // 2)` to `subpatch_size // 2`. If `remove_center` is True,
    0 is excluded from that range. The resulting coordinates are clamped to the
    patch borders.

    If `struct_params` is not None, an additional structN2V mask is applied to the
    data. It replaces the pixels in the mask with random values, excluding the
    pixel already manipulated.

    Parameters
    ----------
    patch : torch.Tensor
        Batch of image patches, shape (batch, y, x) or (batch, z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to exclude the zero offset when sampling the replacement pixel,
        by default True.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None
        Random number generator. If None, a new generator is created on
        `patch.device`.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The manipulated patch, and a uint8 mask equal to 1 where the manipulated
        patch differs from the input (computed before the structN2V mask).

    Raises
    ------
    ValueError
        If `subpatch_size` and `remove_center` leave no offset to sample from.
    """
    if rng is None:
        # NOTE(review): a freshly constructed torch.Generator starts from a fixed
        # default seed, so repeated calls without `rng` likely mask the same
        # pixels; confirm this is intended.
        rng = torch.Generator(device=patch.device)

    transformed_patch = patch.clone()

    # (num_coordinates, 1 + num_spatial_dims), first column is the batch index
    subpatch_centers = _get_stratified_coords_torch(
        mask_pixel_percentage, patch.shape, rng
    ).to(device=patch.device, dtype=torch.long)

    # candidate offsets, along each spatial axis, around the masked pixel
    roi_span = torch.arange(
        -(subpatch_size // 2),
        subpatch_size // 2 + 1,
        dtype=torch.long,
        device=patch.device,
    )
    if remove_center:
        roi_span = roi_span[roi_span != 0]
    if roi_span.numel() == 0:
        raise ValueError(
            f"No replacement offset available for subpatch_size={subpatch_size} "
            f"and remove_center={remove_center}."
        )

    # Draw a *position* in roi_span uniformly, one per spatial axis (the batch
    # coordinate is never offset). Sampling positions in [0, len) keeps every
    # offset equally likely; drawing in [min, max] would wrap negative indices
    # and over-weight some offsets.
    random_increment = roi_span[
        torch.randint(
            low=0,
            high=roi_span.numel(),
            size=(subpatch_centers.shape[0], subpatch_centers.shape[1] - 1),
            generator=rng,
            device=patch.device,
        )
    ]

    # replacement pixel coordinates, clamped to the spatial bounds of the patch
    upper = torch.tensor(patch.shape[1:], device=patch.device) - 1
    replacement_coords = subpatch_centers.clone()
    replacement_coords[:, 1:] = torch.clamp(
        replacement_coords[:, 1:] + random_increment,
        min=torch.zeros_like(upper),
        max=upper,
    )

    # read from the untouched input so that replacements never chain
    transformed_patch[tuple(subpatch_centers.T)] = patch[tuple(replacement_coords.T)]

    # mask of the pixels whose value actually changed
    mask = (transformed_patch != patch).to(dtype=torch.uint8)

    if struct_params is not None:
        transformed_patch = _apply_struct_mask_torch(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return transformed_patch, mask
256
+
257
+
258
def median_manipulate_torch(
    batch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: StructMaskParameters | None = None,
    rng: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    This is the N2V2 version. Pixels to manipulate are selected with a stratified
    grid over the spatial dimensions of every sample. This gives an approximately
    uniform probability of selection across the whole patch. Each selected pixel
    is replaced by the median of the subpatch centered on it, excluding the pixel
    itself. Subpatch coordinates are clamped to the batch borders.

    If `struct_params` is not None, the structN2V line through the pixel is also
    excluded from the median. An additional structN2V mask is then applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    batch : torch.Tensor
        Batch of image patches, shape (batch, y, x) or (batch, z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed over, by default 11. Offsets
        range from `-(subpatch_size // 2)` to `subpatch_size // 2`.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None, optional
        Random number generator, by default None. If None, a new generator is
        created on `batch.device`.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The manipulated batch, and a uint8 mask equal to 1 where the manipulated
        batch differs from the input (computed before the structN2V mask).
    """
    if rng is None:
        # the coordinate sampler reads `rng.device`, so a generator is required
        rng = torch.Generator(device=batch.device)

    # (num_coordinates, 1 + num_spatial_dims), first column is the batch index
    centers = _get_stratified_coords_torch(
        mask_pixel_percentage, batch.shape, rng
    ).to(device=batch.device, dtype=torch.long)
    n_spatial_dims = centers.shape[1] - 1

    # every offset of the ROI, shape (n_offsets, n_spatial_dims), row-major order
    half = subpatch_size // 2
    axis_range = torch.arange(-half, half + 1, device=batch.device)
    offsets = torch.stack(
        [
            grid.flatten()
            for grid in torch.meshgrid([axis_range] * n_spatial_dims, indexing="ij")
        ],
        dim=1,
    )
    n_offsets = offsets.shape[0]

    # Gather the ROIs. The batch index is repeated for every offset, while the
    # spatial indices are shifted by the offsets and clamped to the borders.
    roi_index = [centers[:, :1].expand(-1, n_offsets)]
    for d in range(n_spatial_dims):
        roi_index.append(
            (centers[:, d + 1].unsqueeze(1) + offsets[:, d]).clamp(
                0, batch.shape[d + 1] - 1
            )
        )
    rois = batch[tuple(roi_index)]  # (num_coordinates, n_offsets)

    if struct_params is not None:
        # Exclude the structN2V line through the center (center included). Axis 0
        # runs along the last dim (x), axis 1 along the second-to-last dim (y).
        halfspan = (struct_params.span - 1) // 2
        moving_dim = n_spatial_dims - 1 - struct_params.axis
        other_dims = [d for d in range(n_spatial_dims) if d != moving_dim]
        excluded = offsets[:, other_dims].eq(0).all(dim=1) & (
            offsets[:, moving_dim].abs() <= halfspan
        )
    else:
        # only exclude the center pixel itself
        excluded = (offsets == 0).all(dim=1)

    medians = rois[:, ~excluded].median(dim=1).values  # (num_coordinates,)

    output_batch = batch.clone()
    output_batch[tuple(centers.T)] = medians
    mask = torch.where(output_batch != batch, 1, 0).to(torch.uint8)

    if struct_params is not None:
        # Reuse the caller's generator when it lives on the batch device, so the
        # structN2V values follow the same seed; otherwise let the helper create
        # its own generator on the right device.
        struct_rng = rng if rng.device == batch.device else None
        output_batch = _apply_struct_mask_torch(
            output_batch, centers, struct_params, struct_rng
        )

    return output_batch, mask
@@ -0,0 +1,20 @@
1
+ """Class representing the parameters of structN2V masks."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+
6
+
7
@dataclass
class StructMaskParameters:
    """Settings describing a structN2V mask.

    Attributes
    ----------
    axis : Literal[0, 1]
        Orientation of the masked line: horizontal (0) or vertical (1).
    span : int
        Span (length) of the mask.
    """

    # orientation of the masked line: 0 horizontal, 1 vertical
    axis: Literal[0, 1]
    # span (length) of the masked line
    span: int
@@ -0,0 +1,24 @@
1
+ """A general parent class for transforms."""
2
+
3
+ from typing import Any
4
+
5
+
6
class Transform:
    """Common base class for transforms.

    Subclasses override ``__call__``; the implementation provided here accepts
    any arguments and performs no work.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Apply the transform.

        Parameters
        ----------
        *args : Any
            Positional arguments, ignored by this base implementation.
        **kwargs : Any
            Keyword arguments, ignored by this base implementation.

        Returns
        -------
        Any
            Transformed data; this base implementation always returns ``None``.
        """
        return None
@@ -0,0 +1,88 @@
1
+ """Test-time augmentations."""
2
+
3
+ from torch import Tensor, flip, mean, rot90, stack
4
+
5
+
6
class ImageRestorationTTA:
    """
    Test-time augmentation for image restoration tasks.

    Eight views of the input are produced: the original, its rotations by one,
    two and three quarter turns, its flips along Y and along X, and the
    once-rotated image flipped along Y and along X. ``backward`` maps every view
    back to the original orientation and averages them.

    Tensors should be of shape SC(Z)YX; only the last two (YX) axes are
    transformed.

    This transformation is used in the LightningModule in order to perform test-time
    augmentation.
    """

    def forward(self, input_tensor: Tensor) -> list[Tensor]:
        """
        Apply test-time augmentation to the input tensor.

        Parameters
        ----------
        input_tensor : Tensor
            Input tensor, shape SC(Z)YX.

        Returns
        -------
        list of torch.Tensor
            The eight augmented views, in the order expected by ``backward``:
            original, one/two/three quarter turns, flip along Y, flip along X,
            quarter turn flipped along Y, quarter turn flipped along X.
        """
        # Y and X are always the two trailing dimensions.
        yx_dims = (-2, -1)

        # The identity view is the input tensor itself, not a copy.
        views = [input_tensor]
        views += [rot90(input_tensor, turns, dims=yx_dims) for turns in (1, 2, 3)]
        views += [flip(input_tensor, dims=(dim,)) for dim in yx_dims]

        # Flipping the quarter-turned image yields the two diagonal reflections.
        quarter_turn = views[1]
        views += [flip(quarter_turn, dims=(dim,)) for dim in yx_dims]

        return views

    def backward(self, x: list[Tensor]) -> Tensor:
        """Undo the test-time augmentation and average the results.

        Parameters
        ----------
        x : list of torch.Tensor
            The eight augmented tensors of shape SC(Z)YX, in the order produced
            by ``forward``.

        Returns
        -------
        torch.Tensor
            Element-wise mean of the de-augmented tensors.
        """
        yx_dims = (-2, -1)

        # Pure rotations are undone by turning back the same number of times.
        restored = [x[0]]
        restored += [rot90(x[turns], -turns, dims=yx_dims) for turns in (1, 2, 3)]

        # Flips are their own inverse.
        restored += [flip(x[4 + i], dims=(dim,)) for i, dim in enumerate(yx_dims)]

        # "Rotate once, then flip" is undone by flipping back, then rotating back.
        restored += [
            rot90(flip(x[6 + i], dims=(dim,)), -1, dims=yx_dims)
            for i, dim in enumerate(yx_dims)
        ]

        return mean(stack(restored), dim=0)
@@ -0,0 +1,131 @@
1
+ """XY flip transform."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+ from careamics.transforms.transform import Transform
7
+
8
+
9
class XYFlip(Transform):
    """Flip image along X and Y axis, one at a time.

    This transform randomly flips one of the last two axes.

    This transform expects C(Z)YX dimensions.

    Attributes
    ----------
    axis_indices : list[int]
        Indices of the axes that can be flipped (-2 for Y, -1 for X).
    rng : np.random.Generator
        Random number generator.
    p : float
        Probability of applying the transform.

    Parameters
    ----------
    flip_x : bool, optional
        Whether to flip along the X axis, by default True.
    flip_y : bool, optional
        Whether to flip along the Y axis, by default True.
    p : float, optional
        Probability of applying the transform, by default 0.5.
    seed : Optional[int], optional
        Random seed, by default None.
    """

    def __init__(
        self,
        flip_x: bool = True,
        flip_y: bool = True,
        p: float = 0.5,
        seed: int | None = None,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        flip_x : bool, optional
            Whether to flip along the X axis, by default True.
        flip_y : bool, optional
            Whether to flip along the Y axis, by default True.
        p : float
            Probability of applying the transform, by default 0.5.
        seed : Optional[int], optional
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `p` is not in [0, 1] (NaN included), or if both `flip_x` and
            `flip_y` are False.
        """
        # Chained comparison so that NaN is rejected too: every comparison
        # involving NaN is False, which the former `p < 0 or p > 1` let through.
        if not 0 <= p <= 1:
            raise ValueError("Probability must be in [0, 1].")

        if not flip_x and not flip_y:
            raise ValueError("At least one axis must be flippable.")

        # probability to apply the transform
        self.p = p

        # "flippable" axes, counted from the end so both C(Z)YX layouts work
        self.axis_indices: list[int] = []
        if flip_y:
            self.axis_indices.append(-2)
        if flip_x:
            self.axis_indices.append(-1)

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self,
        patch: NDArray,
        target: NDArray | None = None,
        **additional_arrays: NDArray,
    ) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
        """Apply the transform to the source patch and the target (optional).

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : Optional[np.ndarray], optional
            Target for the patch, by default None.
        **additional_arrays : NDArray
            Additional arrays that will be transformed identically to `patch` and
            `target`.

        Returns
        -------
        tuple of (np.ndarray, np.ndarray or None, dict of str to np.ndarray)
            Transformed patch, transformed target (None if no target was given)
            and the transformed additional arrays, keyed as passed in. When the
            transform is not applied, the inputs are returned unchanged.
        """
        # `Generator.random` draws from [0, 1): skipping when the draw is >= p
        # applies the flip with probability exactly p (never for p == 0, always
        # for p == 1).
        if self.rng.random() >= self.p:
            return patch, target, additional_arrays

        # choose an axis to flip (same axis for every array so they stay aligned)
        axis = int(self.rng.choice(self.axis_indices))

        patch_transformed = self._apply(patch, axis)
        target_transformed = self._apply(target, axis) if target is not None else None
        additional_transformed = {
            key: self._apply(array, axis) for key, array in additional_arrays.items()
        }

        return patch_transformed, target_transformed, additional_transformed

    def _apply(self, patch: NDArray, axis: int) -> NDArray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        axis : int
            Axis to flip.

        Returns
        -------
        np.ndarray
            Flipped image patch, as a C-contiguous array.
        """
        # np.flip returns a negatively-strided view; make it contiguous.
        return np.ascontiguousarray(np.flip(patch, axis=axis))