careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,377 @@
1
+ """MicroSplit patch synthesis."""
2
+
3
+ # --- PROOF OF PRINCIPLE ---
4
+
5
+
6
import logging
from collections.abc import Callable, Sequence
from typing import Any, Literal, NamedTuple

import numpy as np
from numpy.typing import NDArray

from .dataset import ImageRegionData
from .image_stack import ImageStack
from .patch_extractor import PatchExtractor
from .patch_filter import PatchFilterProtocol
from .patching_strategies import PatchingStrategy, PatchSpecs
17
+
18
+
19
# TODO: better name
class UncorrelatedRegionData(NamedTuple):
    """
    Patch data plus metadata for a patch whose channels may come from different places.

    The fields mirror those of `ImageRegionData`. The difference is that the source,
    data shape, dtype and region spec are sequences, filled with one entry per
    extracted channel patch. Each channel may have been taken from a different image
    stack and location.
    """

    # the patch array
    data: NDArray
    # source of the image stack each channel patch was taken from
    source: Sequence[str | Literal["array"]]
    # data shape of the image stack each channel patch was taken from
    data_shape: Sequence[Sequence[int]]
    # dtype of each image stack; kept as a string so it can be collated
    dtype: Sequence[str]
    # axes of the data, only used as metadata
    axes: Sequence[str]
    # patch spec used to extract each channel patch
    region_spec: Sequence[PatchSpecs]
28
+
29
+
30
# --- for finding empty / signal channel patches in loop
def is_empty(filter: PatchFilterProtocol) -> Callable[[NDArray[Any]], bool]:
    """
    Build a predicate that reports whether a patch is filtered out.

    Parameters
    ----------
    filter : PatchFilterProtocol
        The patch filter whose `filter_out` decision is used.

    Returns
    -------
    Callable[[NDArray[Any]], bool]
        A function that takes a patch and returns `filter.filter_out(patch)`,
        i.e. `True` when the filter considers the patch empty.
    """
    # local alias so the body does not read the name shadowing the builtin
    patch_filter = filter

    def is_empty_check(patch: NDArray[Any]) -> bool:
        return patch_filter.filter_out(patch)

    return is_empty_check
36
+
37
+
38
def is_not_empty(filter: PatchFilterProtocol) -> Callable[[NDArray[Any]], bool]:
    """
    Build a predicate that reports whether a patch is kept by the filter.

    Parameters
    ----------
    filter : PatchFilterProtocol
        The patch filter whose `filter_out` decision is negated.

    Returns
    -------
    Callable[[NDArray[Any]], bool]
        A function that takes a patch and returns `not filter.filter_out(patch)`,
        i.e. `True` when the filter considers the patch to contain signal.
    """
    # local alias so the body does not read the name shadowing the builtin
    patch_filter = filter

    def is_not_empty_check(patch: NDArray[Any]) -> bool:
        filtered_out = patch_filter.filter_out(patch)
        return not filtered_out

    return is_not_empty_check
43
+
44
+
45
+ # ---
46
+
47
+
48
def create_default_input_target(
    idx: int,
    patch_extractor: PatchExtractor[ImageStack],
    patching_strategy: PatchingStrategy,
    alphas: list[float],
    axes: str,  # annoyingly have to supply this to image region
) -> tuple[ImageRegionData, ImageRegionData]:
    """
    Create a default MicroSplit patch with synthetically summed input.

    All channels are extracted from the single patch location given by `idx`.

    Parameters
    ----------
    idx : int
        The dataset index, passed to `patching_strategy.get_patch_spec`.
    patch_extractor : PatchExtractor
        Used to extract patches from the data.
    patching_strategy : PatchingStrategy
        Patch locations will be sampled using the patching strategy.
    alphas : list[float]
        Weights for each channel for creating the synthetic input with summation.
        Exactly one weight per channel is required.
    axes : str
        The axes of the data. This is only used to populate metadata.

    Returns
    -------
    input_region : ImageRegionData
        The input patch and its metadata, the data has the dimensions L(Z)YX.
    target_region : ImageRegionData
        The target patch and its metadata, the data has the dimensions C(Z)YX.

    Raises
    ------
    ValueError
        If the number of `alphas` does not match the number of channels.
    """
    patch_spec = patching_strategy.get_patch_spec(idx)
    # patches are CL(Z)YX: channel axis first, lateral-context axis second
    patches = extract_microsplit_patch(patch_extractor, patch_spec)

    n_channels = patches.shape[0]
    if len(alphas) != n_channels:
        # a single alpha would otherwise silently broadcast over all channels
        raise ValueError(
            f"Expected one alpha per channel ({n_channels}), but got "
            f"{len(alphas)} alphas."
        )
    # shape the weights as (C, 1, 1, ...) so they broadcast over the non-channel
    # axes; `np.reshape` avoids star-unpacking inside a subscript, which is only
    # valid syntax from Python 3.11
    alpha_broadcast = np.reshape(alphas, (n_channels,) + (1,) * (patches.ndim - 1))
    # weight channels by alphas then sum on the channel axis -> L(Z)YX
    input_patch = (alpha_broadcast * patches).sum(axis=0)
    # primary (first) lateral-context level of every channel -> C(Z)YX
    target_patch = patches[:, 0, ...]

    image_stack = patch_extractor.image_stacks[patch_spec["data_idx"]]
    source = str(image_stack.source)
    data_shape = image_stack.data_shape
    dtype = str(image_stack.data_dtype)

    input_region = ImageRegionData(
        input_patch,
        source=source,
        data_shape=data_shape,
        dtype=dtype,
        axes=axes,
        region_spec=patch_spec,
        additional_metadata={},
    )
    target_region = ImageRegionData(
        target_patch,
        source=source,
        data_shape=data_shape,
        dtype=dtype,
        axes=axes,
        region_spec=patch_spec,
        additional_metadata={},
    )
    return input_region, target_region
108
+
109
+
110
def create_uncorrelated_input_target(
    patches: NDArray[Any],
    patch_specs: list[PatchSpecs],
    alphas: list[float],
    patch_extractor: PatchExtractor[ImageStack],  # for metadata
    axes: str,  # mirroring imageregion
) -> tuple[UncorrelatedRegionData, UncorrelatedRegionData]:
    """
    Create MicroSplit target and synthetically summed input with metadata.

    Parameters
    ----------
    patches : NDArray
        Patches with dimensions CL(Z)YX (channel axis first), where L contains the
        lateral context at multiple scales.
    patch_specs : list[PatchSpecs]
        The patch specs for each channel.
    alphas : list[float]
        Weights for each channel for creating the synthetic input with summation.
        Exactly one weight per channel is required.
    patch_extractor : PatchExtractor
        The patch extractor the patches were extracted from. Used for additional
        metadata.
    axes : str
        The axes of the data. This is only used to populate metadata.

    Returns
    -------
    input_region : UncorrelatedRegionData
        The input patch and its metadata, the data has the dimensions L(Z)YX.
    target_region : UncorrelatedRegionData
        The target patch and its metadata, the data has the dimensions C(Z)YX.

    Raises
    ------
    ValueError
        If the number of `alphas` does not match the number of channels.
    """
    n_channels = patches.shape[0]
    if len(alphas) != n_channels:
        # a single alpha would otherwise silently broadcast over all channels
        raise ValueError(
            f"Expected one alpha per channel ({n_channels}), but got "
            f"{len(alphas)} alphas."
        )
    # shape the weights as (C, 1, 1, ...) so they broadcast over the non-channel
    # axes; `np.reshape` avoids star-unpacking inside a subscript, which is only
    # valid syntax from Python 3.11
    alpha_broadcast = np.reshape(alphas, (n_channels,) + (1,) * (patches.ndim - 1))
    # weight channels by alphas then sum on the channel axis -> L(Z)YX
    input_patch = (alpha_broadcast * patches).sum(axis=0)
    # primary (first) lateral-context level of every channel -> C(Z)YX
    target_patch = patches[:, 0, ...]

    # metadata is collected per patch spec, i.e. per channel
    input_stacks = [
        patch_extractor.image_stacks[patch_spec["data_idx"]]
        for patch_spec in patch_specs
    ]
    source = [str(stack.source) for stack in input_stacks]
    data_shape = [stack.data_shape for stack in input_stacks]
    dtype = [str(stack.data_dtype) for stack in input_stacks]

    input_region = UncorrelatedRegionData(
        data=input_patch,
        source=source,
        data_shape=data_shape,
        dtype=dtype,
        region_spec=patch_specs,
        axes=axes,
    )
    target_region = UncorrelatedRegionData(
        data=target_patch,
        source=source,
        data_shape=data_shape,
        dtype=dtype,
        region_spec=patch_specs,
        axes=axes,
    )
    return input_region, target_region
172
+
173
+
174
def get_random_channel_patches(
    idx: int,  # TODO: is this needed it makes it work the same as original dataset
    patch_extractor: PatchExtractor[ImageStack],
    patching_strategy: PatchingStrategy,
    rng: np.random.Generator | None,
) -> tuple[NDArray[Any], list[PatchSpecs]]:
    """
    Select patches from random patch locations for each channel.

    The first channel uses the patch at `idx`. Every other channel uses a patch
    index drawn with `rng` from all `patching_strategy.n_patches` patches.

    Parameters
    ----------
    idx : int
        The dataset index, used for the first channel.
    patch_extractor : PatchExtractor
        Used to extract patches from the data.
    patching_strategy : PatchingStrategy
        Patch locations will be sampled using the patching strategy.
    rng : numpy.random.Generator | None
        Useful for seeding the process. If `None` the default random number generator
        will be used.

    Returns
    -------
    NDArray[Any]
        The resulting patches with dimensions CL(Z)YX (channel axis first), where L
        contains the lateral context at multiple scales.
    list[PatchSpecs]
        A list of patch specifications, one for each channel.
    """
    generator = np.random.default_rng() if rng is None else rng

    # as in the original dataset, every channel except the first gets its own
    # randomly drawn patch index from anywhere in the dataset
    n_random_channels = patch_extractor.n_channels - 1
    random_indices = generator.integers(
        patching_strategy.n_patches, size=n_random_channels
    )
    patch_indices = [idx, *random_indices]

    # one patch spec per channel
    patch_specs = [patching_strategy.get_patch_spec(i) for i in patch_indices]
    return extract_microsplit_patch(patch_extractor, patch_specs), patch_specs
217
+
218
+
219
# TODO: better name
def get_empty_channel_patches(
    idx: int,
    patch_extractor: PatchExtractor[ImageStack],
    patching_strategy: PatchingStrategy,
    signal_channels: dict[int, PatchFilterProtocol],
    empty_channels: dict[int, PatchFilterProtocol],
    patience: int,
    rng: np.random.Generator | None,
) -> tuple[NDArray[Any], list[PatchSpecs]]:
    """
    Select patches, specifying which channels should have signal and which should not.

    Channels listed in neither `signal_channels` nor `empty_channels` keep the
    patch initially selected for them by `get_random_channel_patches`.

    Parameters
    ----------
    idx : int
        The dataset index.
    patch_extractor : PatchExtractor
        Used to extract patches from the data.
    patching_strategy : PatchingStrategy
        Patch locations will be sampled using the patching strategy.
    signal_channels : dict[int, PatchFilterProtocol]
        A dictionary to specify the channels that should have signal and how they should
        be filtered. The keys are the channel index and the values are the patch filters
        used to determine if the channel patch is empty or not.
    empty_channels : dict[int, PatchFilterProtocol]
        A dictionary to specify the channels that should not have signal. Similar to
        the `signal_channels`.
    patience : int
        New patches are selected at random until a patch with signal or without is
        found, the `patience` determines how many times to look before giving up.
    rng : numpy.random.Generator | None
        Useful for seeding the process. If `None` the default random number generator
        will be used.

    Returns
    -------
    NDArray[Any]
        The resulting patches with dimensions CL(Z)YX (channel axis first), where L
        contains the lateral context at multiple scales.
    list[PatchSpecs]
        A list of patch specifications, one for each channel.

    Raises
    ------
    ValueError
        If a channel is selected as both empty and filled.
    """
    if rng is None:
        rng = np.random.default_rng()

    filled = set(signal_channels.keys())
    empty = set(empty_channels.keys())
    if len(intersect := filled.intersection(empty)) != 0:
        raise ValueError(
            "Channels cannot be selected as both empty and filled, the following "
            f"channels were selected as both {intersect}."
        )

    n_channels = patch_extractor.n_channels
    n_spatial_dims = patch_extractor.n_spatial_dims

    # start with random initial patches
    patches, patch_specs = get_random_channel_patches(
        idx, patch_extractor, patching_strategy, rng
    )

    # for each constrained channel, resample until the patch is empty / not empty
    for c in range(n_channels):
        # criterion for the while loop: True means "resample"
        criterion: Callable[[NDArray[Any]], bool]
        if c in empty_channels:
            criterion = is_not_empty(empty_channels[c])
        elif c in signal_channels:
            criterion = is_empty(signal_channels[c])
        else:
            # unconstrained channel: keep its initial patch but still process
            # the remaining channels (a `break` here would skip them)
            continue

        patch = patches[c]  # L(Z)YX
        patch_spec = patch_specs[c]
        patience_ = patience
        # only check if primary input is empty
        while criterion(patch[0]) and patience_ > 0:
            # sample random indices from anywhere in the dataset
            new_idx = rng.integers(patching_strategy.n_patches)
            patch_spec = patching_strategy.get_patch_spec(new_idx.item())
            patch = patch_extractor.extract_channel_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                channels=[c],
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )[0]
            # ^ removing channel dim
            if patch.ndim == n_spatial_dims:
                # no lateral context: restore the L dim so that `patch[0]` is the
                # whole primary patch, matching the layout of `patches[c]`
                patch = patch[np.newaxis]
            patience_ -= 1
        # the loop only stops early if the criterion was met, so a re-check is
        # needed only once the patience has been used up
        if patience_ <= 0 and criterion(patch[0]):
            logging.getLogger(__name__).warning(
                "Out of patience finding patch for channel %s", c
            )

        patches[c] = patch
        patch_specs[c] = patch_spec

    return patches, patch_specs
321
+
322
+
323
def extract_microsplit_patch(
    patch_extractor: PatchExtractor[ImageStack],
    patch_specs: PatchSpecs | list[PatchSpecs],
) -> NDArray[Any]:
    """
    Extract a MicroSplit patch with the dimensions CL(Z)YX.

    This patch can be used to synthesise an input patch by summing over the C
    dimension (axis 0). It can also be used to create a target patch by selecting
    the primary input from the L dimension (axis 1), where L stores the lateral
    context patches.

    Parameters
    ----------
    patch_extractor : PatchExtractor
        Used to extract patches from the data.
    patch_specs : PatchSpecs | list[PatchSpecs]
        A patch specification or a list of patch specifications, one for each
        channel. Different patch specs can be used for each channel to create
        uncorrelated channel patches.

    Returns
    -------
    NDArray[Any]
        The resulting patches with dimensions CL(Z)YX, where L contains the lateral
        context at multiple scales.
    """
    if isinstance(patch_specs, list):
        # channel `c` is extracted at the location given by the c-th spec
        channel_patches = [
            patch_extractor.extract_channel_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                channels=[c],
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )
            for c, patch_spec in enumerate(patch_specs)
        ]
        # `np.concatenate` instead of `np.concat`: the latter alias only exists
        # from NumPy 2.0
        patches = np.concatenate(channel_patches, axis=0)
    else:
        patches = patch_extractor.extract_patch(
            data_idx=patch_specs["data_idx"],
            sample_idx=patch_specs["sample_idx"],
            coords=patch_specs["coords"],
            patch_size=patch_specs["patch_size"],
        )

    # without lateral context the patch has only the channel axis besides the
    # spatial ones; insert a singleton L dim after it
    if patches.ndim - patch_extractor.n_spatial_dims != 2:
        patches = patches[:, np.newaxis]

    return patches
@@ -0,0 +1,7 @@
1
# public API of the patch_extractor subpackage
__all__ = ["LimitFilesPatchExtractor", "PatchExtractor"]
5
+
6
+ from .limit_file_extractor import LimitFilesPatchExtractor
7
+ from .patch_extractor import PatchExtractor
@@ -0,0 +1,50 @@
1
+ from collections.abc import Sequence
2
+
3
+ from numpy.typing import NDArray
4
+
5
+ from ..image_stack import FileImageStack
6
+ from .patch_construction import PatchConstructor, default_patch_constr
7
+ from .patch_extractor import PatchExtractor
8
+
9
+
10
class LimitFilesPatchExtractor(PatchExtractor[FileImageStack]):
    """
    A patch extractor that limits the number of files that have their data loaded.

    This is useful for when not all of the data will fit into memory. Before a
    patch is extracted from a stack that is not loaded, that stack is loaded. If
    `max_loaded_stacks` stacks are already loaded, the stack that was loaded
    longest ago is closed first.
    """

    def __init__(
        self,
        image_stacks: Sequence[FileImageStack],
        patch_constructor: PatchConstructor = default_patch_constr,
        max_loaded_stacks: int = 1,
    ):
        """
        Parameters
        ----------
        image_stacks : Sequence of `FileImageStack`
            The image stacks to extract patches from.
        patch_constructor : PatchConstructor, default=default_patch_constr
            Passed on to `PatchExtractor`; determines how patches are constructed.
        max_loaded_stacks : int, default=1
            The maximum number of image stacks that have their data loaded at the
            same time.

        Raises
        ------
        ValueError
            If `max_loaded_stacks` is less than 1.
        """
        if max_loaded_stacks < 1:
            raise ValueError(
                f"`max_loaded_stacks` must be at least 1, got {max_loaded_stacks}."
            )
        super().__init__(image_stacks, patch_constructor)
        self.max_loaded_stacks = max_loaded_stacks
        # indices into `image_stacks` of the currently loaded stacks, oldest first
        self.loaded_stacks: list[int] = []

    def extract_channel_patch(
        self,
        data_idx: int,
        sample_idx: int,
        channels: Sequence[int] | None,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """
        Extract a patch, loading the image stack at `data_idx` first if needed.

        Parameters
        ----------
        data_idx : int
            Index of the image stack to extract from.
        sample_idx : int
            Sample index within the image stack.
        channels : Sequence of int or None
            Channels to extract; forwarded to `PatchExtractor.extract_channel_patch`.
        coords : Sequence of int
            The coordinates that define the start of the patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            The patch returned by `PatchExtractor.extract_channel_patch`.
        """
        if data_idx not in self.loaded_stacks:
            # close the stacks that were added longest ago until there is room
            # (a loop, so a lowered `max_loaded_stacks` is still respected)
            while len(self.loaded_stacks) >= self.max_loaded_stacks:
                idx_to_close = self.loaded_stacks.pop(0)
                self.image_stacks[idx_to_close].close()

            self.image_stacks[data_idx].load()
            self.loaded_stacks.append(data_idx)

        return super().extract_channel_patch(
            data_idx, sample_idx, channels, coords, patch_size
        )
@@ -0,0 +1,151 @@
1
+ from collections.abc import Sequence
2
+ from typing import Any, Literal, Protocol
3
+
4
+ import numpy as np
5
+ from numpy.typing import NDArray
6
+ from skimage.transform import resize
7
+
8
+ from ..image_stack import ImageStack
9
+
10
+
11
class PatchConstructor(Protocol):
    """
    A callable that modifies how patches are constructed in the PatchExtractor.

    This protocol defines the signature of a callable that is passed as an argument to
    the `PatchExtractor`. It can be used to modify how patches are constructed, for
    example creating patches with multiple lateral context levels for MicroSplit.
    """

    def __call__(
        self,
        image_stack: ImageStack,
        sample_idx: int,
        channels: Sequence[int] | None,  # `channels = None` to select all channels
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray[Any]:
        """
        Construct a patch from an image stack.

        Parameters
        ----------
        image_stack : ImageStack
            The image stack to construct a patch from.
        sample_idx : int
            Sample index. The first dimension of the image data will be indexed at this
            value.
        channels : Sequence of int or None
            The channels to include in the patch. If `None`, all channels are
            selected.
        coords : Sequence of int
            The coordinates that define the start of a patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            The patch.
        """
        ...
47
+
48
+
49
def default_patch_constr(
    image_stack: ImageStack,
    sample_idx: int,
    channels: Sequence[int] | None,  # `channels = None` to select all channels
    coords: Sequence[int],
    patch_size: Sequence[int],
) -> NDArray[Any]:
    """
    Construct a patch by delegating directly to the image stack.

    This is the default `PatchConstructor` of `PatchExtractor`. It adds no
    processing on top of `image_stack.extract_channel_patch`.

    Parameters
    ----------
    image_stack : ImageStack
        The image stack to construct a patch from.
    sample_idx : int
        Sample index. The first dimension of the image data will be indexed at this
        value.
    channels : Sequence of int or None
        The channels to include in the patch. If `None`, all channels are selected.
    coords : Sequence of int
        The coordinates that define the start of a patch.
    patch_size : Sequence of int
        The size of the patch in each spatial dimension.

    Returns
    -------
    numpy.ndarray
        The patch returned by the image stack.
    """
    extract = image_stack.extract_channel_patch
    return extract(
        coords=coords,
        patch_size=patch_size,
        sample_idx=sample_idx,
        channels=channels,
    )
62
+
63
+
64
# closure to create constructor funcs with particular multiscale_count and padding mode
def lateral_context_patch_constr(
    # TODO: will we stick with this as the parameter name
    multiscale_count: int,
    # TODO: add other modes?
    padding_mode: Literal["reflect", "wrap"],
) -> PatchConstructor:
    """
    Create a lateral context `PatchConstructor` for MicroSplit.

    Each level `s` in `range(multiscale_count)` works as follows. A region
    `2**s` times the patch size is taken, centred on the centre of the requested
    patch. Any part of that region outside the image is filled using `numpy.pad`
    with `padding_mode`. The region is then resized to the patch size, so level
    0 corresponds to the requested patch itself.

    Parameters
    ----------
    multiscale_count : int
        The number of multiscale inputs that will be created including the original
        image size.
    padding_mode : {"reflect", "wrap"}
        How lateral context inputs will be padded at the edge of the image. See
        [`numpy.pad`](https://numpy.org/devdocs/reference/generated/numpy.pad.html) for
        more information.

    Returns
    -------
    PatchConstructor
        The patch constructor function. It will return patches with the dimensions
        (C, L, (Z), Y, X) where L will be equal to `multiscale_count`, C is the number
        of channels in the image, and (Z), Y, X are the patch size.
    """

    def constructor_func(
        image_stack: ImageStack,
        sample_idx: int,
        channels: Sequence[int] | None,  # `channels = None` to select all channels
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray[Any]:
        """
        Construct a patch with lateral context levels stacked along axis 1.

        Parameters
        ----------
        image_stack : ImageStack
            The image stack to construct a patch from.
        sample_idx : int
            Sample index. The first dimension of the image data will be indexed at
            this value.
        channels : Sequence of int or None
            A single selected channel, or `None` to select all channels.
        coords : Sequence of int
            The coordinates that define the start of a patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            Array with dimensions (C, L, (Z), Y, X).

        Raises
        ------
        NotImplementedError
            If more than one channel is selected.
        """
        if channels is not None and len(channels) > 1:
            raise NotImplementedError(
                "Selecting multiple channels is currently not implemented for lateral "
                "context patches. Select a single channel or pass `channels=None` to "
                "select all channels."
            )

        # `data_shape` is indexed as S, C, (Z), Y, X
        shape = image_stack.data_shape
        spatial_shape = np.array(shape[2:])
        n_channels = shape[1] if channels is None else 1

        patch_size_arr = np.array(patch_size)
        # Every lateral context level is centred on the centre of the requested patch.
        centre = np.array(coords) + patch_size_arr // 2
        lower_bound = np.zeros_like(spatial_shape)

        # There will now be an additional lc dimension,
        # this has to be handled correctly by the dataset
        # TODO: maybe we want to limit this constructor to only images with 1 channel
        # then we can put LCs in the channel dimension
        # but not sure if this artificially limits potential use-cases
        patch = np.zeros((n_channels, multiscale_count, *patch_size))
        for scale in range(multiscale_count):
            lc_patch_size = patch_size_arr * (2**scale)
            lc_start = centre - lc_patch_size // 2
            lc_end = lc_start + lc_patch_size

            # Only request the part of the region that lies inside the image.
            start_clipped = np.clip(lc_start, lower_bound, spatial_shape)
            end_clipped = np.clip(lc_end, lower_bound, spatial_shape)
            size_clipped = end_clipped - start_clipped

            lc_patch = image_stack.extract_channel_patch(
                sample_idx, channels, start_clipped, size_clipped
            )

            # Pad the region back to its full size. The channel axis is not padded.
            # `np.concatenate` is used because `np.concat` only exists in NumPy >= 2.0.
            pad_width = np.concatenate(
                [
                    np.zeros((1, 2), dtype=int),
                    np.stack([start_clipped - lc_start, lc_end - end_clipped], axis=-1),
                ]
            )
            lc_patch = np.pad(lc_patch, pad_width, mode=padding_mode)

            # TODO: test different downscaling? skimage suggests downscale_local_mean
            # `preserve_range=True` keeps the original intensity values. Without it,
            # skimage rescales integer images to [0, 1], which would make these
            # patches inconsistent with those from `default_patch_constr`.
            lc_patch = resize(lc_patch, (n_channels, *patch_size), preserve_range=True)
            patch[:, scale, ...] = lc_patch
        return patch

    return constructor_func
@@ -0,0 +1,117 @@
1
+ from collections.abc import Sequence
2
+ from typing import Generic
3
+
4
+ from numpy.typing import NDArray
5
+
6
+ from ..image_stack import GenericImageStack
7
+ from .patch_construction import PatchConstructor, default_patch_constr
8
+
9
+
10
class PatchExtractor(Generic[GenericImageStack]):
    """
    A class for extracting patches from multiple image stacks.

    All image stacks are expected to hold data with dimensions SC(Z)YX. They must
    share the same number of spatial dimensions and the same number of channels.
    """

    def __init__(
        self,
        image_stacks: Sequence[GenericImageStack],
        patch_constructor: PatchConstructor = default_patch_constr,
    ):
        """
        Parameters
        ----------
        image_stacks : Sequence of ImageStack
            The image stacks to extract patches from. Must not be empty.
        patch_constructor : PatchConstructor, default=default_patch_constr
            Callable used to construct each patch from an image stack.

        Raises
        ------
        ValueError
            If `image_stacks` is empty, or if the image stacks do not all have the
            same number of spatial dimensions and the same number of channels.
        """
        self.patch_constructor = patch_constructor
        self.image_stacks: list[GenericImageStack] = list(image_stacks)

        if not self.image_stacks:
            raise ValueError(
                "A `PatchExtractor` requires at least one `ImageStack`, but an empty "
                "sequence was given."
            )

        # check all image stacks have the same number of dimensions
        # check all image stacks have the same number of channels
        first_shape = self.image_stacks[0].data_shape
        self.n_spatial_dims = len(first_shape) - 2  # SC(Z)YX
        self.n_channels = first_shape[1]
        for i, image_stack in enumerate(self.image_stacks):
            if (ndims := len(image_stack.data_shape) - 2) != self.n_spatial_dims:
                raise ValueError(
                    "All `ImageStack` objects in a `PatchExtractor` must have the same "
                    "number of spatial dimensions. The first image stack is "
                    f"{self.n_spatial_dims}D but found a {ndims}D image stack at index "
                    f"{i}."
                )
            if (n_channels := image_stack.data_shape[1]) != self.n_channels:
                raise ValueError(
                    "All `ImageStack` objects in a `PatchExtractor` must have the same "
                    f"number of channels. The first image stack has {self.n_channels} "
                    f"channels but found an image stack with {n_channels} channels at "
                    f"index {i}."
                )

    def extract_patch(
        self,
        data_idx: int,
        sample_idx: int,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """Extract a patch from the specified image stack across all channels.

        Equivalent to calling `extract_channel_patch` with `channels=None`.

        Parameters
        ----------
        data_idx : int
            Index of the image stack to extract the patch from.
        sample_idx : int
            Sample index. The first dimension of the image data will be indexed at this
            value.
        coords : Sequence of int
            The coordinates that define the start of a patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            The extracted patch.
        """
        return self.extract_channel_patch(
            data_idx=data_idx,
            sample_idx=sample_idx,
            channels=None,
            coords=coords,
            patch_size=patch_size,
        )

    def extract_channel_patch(
        self,
        data_idx: int,
        sample_idx: int,
        channels: Sequence[int] | None,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """Extract a patch from the specified image stack.

        The patch is built by `self.patch_constructor`.

        Parameters
        ----------
        data_idx : int
            Index of the image stack to extract the patch from.
        sample_idx : int
            Sample index. The first dimension of the image data will be indexed at this
            value.
        channels : Sequence of int | None
            Channels to extract. If `None`, all channels are extracted.
        coords : Sequence of int
            The coordinates that define the start of a patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            The extracted patch.
        """
        return self.patch_constructor(
            self.image_stacks[data_idx],
            sample_idx=sample_idx,
            channels=channels,
            coords=coords,
            patch_size=patch_size,
        )

    @property
    def shapes(self) -> list[Sequence[int]]:
        """The `data_shape` of each image stack, in the same order as `image_stacks`."""
        return [stack.data_shape for stack in self.image_stacks]