careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,375 @@
1
+ """Tile Zarr writing strategy."""
2
+
3
+ import builtins
4
+ from collections.abc import Sequence
5
+ from pathlib import Path
6
+
7
+ import zarr
8
+ from numpy import float32
9
+
10
+ from careamics.dataset.dataset_utils.dataset_utils import get_axes_order
11
+ from careamics.dataset_ng.dataset import ImageRegionData
12
+ from careamics.dataset_ng.image_stack_loader.zarr_utils import (
13
+ decipher_zarr_uri,
14
+ is_valid_uri,
15
+ )
16
+ from careamics.dataset_ng.patching_strategies import TileSpecs, is_tile_specs
17
+
18
# Suffix inserted between the source zarr stem and ".zarr" when naming the
# output store, e.g. "data.zarr" -> "data_output.zarr".
OUTPUT_KEY: str = "_output"
19
+
20
+
21
+ def _update_data_shape(axes: str, data_shape: Sequence[int]) -> tuple[int, ...]:
22
+ """Update data shape to remove non existing dimensions.
23
+
24
+ Parameters
25
+ ----------
26
+ axes : str
27
+ Axes string of the original data.
28
+ data_shape : Sequence[int]
29
+ Shape of the array in SC(Z)YX order with potential singleton dimensions.
30
+
31
+ Returns
32
+ -------
33
+ tuple[int, ...]
34
+ Updated shape with non-existing axes removed.
35
+ """
36
+ new_shape = []
37
+
38
+ if "S" in axes:
39
+ new_shape.append(data_shape[0])
40
+
41
+ if "C" in axes:
42
+ new_shape.append(data_shape[1])
43
+
44
+ for idx in range(2, len(data_shape)):
45
+ new_shape.append(data_shape[idx])
46
+
47
+ return tuple(new_shape)
48
+
49
+
50
+ def _update_T_axis(axes: str) -> str:
51
+ """Update axes string to account for multiplexed S and T dimensions.
52
+
53
+ If only `T` is present, then it is relabeled as `S`. If both `S` and `T` are
54
+ present, then `T` is removed.
55
+
56
+ Parameters
57
+ ----------
58
+ axes : str
59
+ Axes string of the original data.
60
+
61
+ Returns
62
+ -------
63
+ str
64
+ Updated axes string.
65
+ """
66
+ if "T" in axes:
67
+ if "S" in axes:
68
+ # remove T
69
+ axes = axes.replace("T", "")
70
+ else:
71
+ # relabel T as S
72
+ axes = axes.replace("T", "S")
73
+ return axes
74
+
75
+
76
def _auto_chunks(
    axes: str, data_shape: Sequence[int], max_chunk_size: int = 128
) -> tuple[int, ...]:
    """Generate automatic chunk sizes based on axes and shape.

    Spatial dimensions (Z, Y, X) are chunked with a maximum size of
    `max_chunk_size` (128 by default). All other dimensions have chunk size 1.

    With the default value and float32 data, a chunk is 128 x 128 x 4 B = 64 KiB
    for 2D data and 128^3 x 4 B = 8 MiB for 3D data.

    Parameters
    ----------
    axes : str
        Axes string of the original data.
    data_shape : Sequence[int]
        Shape of the array in SC(Z)YX order with potential singleton dimensions.
    max_chunk_size : int, default=128
        Maximum chunk size along each spatial dimension.

    Returns
    -------
    tuple[int, ...]
        Chunk sizes for each dimension in SC(Z)YX order, but excluding dimensions that
        are not in the axes string.

    Raises
    ------
    ValueError
        If `max_chunk_size` is smaller than 1.
    """
    if max_chunk_size < 1:
        raise ValueError(
            f"`max_chunk_size` must be a positive integer, got {max_chunk_size}."
        )

    # axes may contain T, which is now multiplexed with S
    updated_axes = _update_T_axis(axes)

    # axes reshaping indices in the order SC(Z)YX
    indices = get_axes_order(updated_axes, ref_axes="SCZYX")

    # `data_shape` always carries leading S and C dims (singletons when absent
    # from the axes); count the missing ones to map positions among the present
    # axes onto positions in `data_shape`
    sczyx_offset = 0
    if "S" not in updated_axes:
        sczyx_offset += 1  # singleton S dim added to data_shape
    if "C" not in updated_axes:
        sczyx_offset += 1  # singleton C dim added to data_shape

    # loop through the original axes in order SC(Z)YX
    # - original_index is the index of the axis in the original `axes` string
    # - idx is the index in SC(Z)YX order of the axes present in `axes`
    # - since all non spatial are treated the same, we can recover the spatial dims
    #   index in SC(Z)YX order by using sczyx_offset
    chunk_sizes: list[int] = []
    for idx, original_index in enumerate(indices):
        axis = updated_axes[original_index]

        # TODO we should probably not chunk along Z (#658)
        if axis in ("Z", "Y", "X"):
            dim_size = data_shape[idx + sczyx_offset]
            chunk_sizes.append(min(max_chunk_size, dim_size))
        else:
            chunk_sizes.append(1)

    return tuple(chunk_sizes)
129
+
130
+
131
def _add_output_key(dirpath: Path, path: str | Path) -> Path:
    """Build the output zarr path corresponding to a source zarr path.

    The stem of `path` (its final component without the last suffix) is
    followed by `OUTPUT_KEY` and `.zarr`, and the resulting name is placed
    inside `dirpath`, e.g. `some/dir/data.zarr` -> `dirpath/data_output.zarr`.

    Parameters
    ----------
    dirpath : Path
        Directory path to save the output zarr.
    path : str | Path
        Original zarr path.

    Returns
    -------
    Path
        Zarr path with `output` key added.
    """
    source_stem = Path(path).stem
    return dirpath / f"{source_stem}{OUTPUT_KEY}.zarr"
149
+
150
+
151
class WriteTilesZarr:
    """Zarr tile writer strategy.

    This writer creates zarr files, groups and arrays as needed and writes tiles
    into the appropriate locations.

    The output store, group and array used for the last tile are cached. They
    are only re-opened when a tile comes from a different source store, group
    or array. Changing a higher level (store, then group) discards the cached
    lower levels, so a tile is never written into an array that belongs to a
    previously processed source.
    """

    def __init__(self) -> None:
        """Constructor."""
        self.current_store: zarr.Group | None = None
        self.current_group: zarr.Group | None = None
        self.current_array: zarr.Array | None = None

        # identifiers of the cached store/group/array, compared against the
        # location of each incoming tile to decide what must be re-opened
        self._current_store_path: Path | None = None
        self._current_group_path: str | None = None
        self._current_array_name: str | None = None

    def _create_zarr(self, store: str | Path) -> None:
        """Create or open a zarr storage and make it the current store.

        A new root group is created if nothing exists at `store`, otherwise the
        existing store is opened. The cached group and array are reset, since
        they belong to the previous store.

        Parameters
        ----------
        store : str | Path
            Path to the zarr store.

        Raises
        ------
        RuntimeError
            If an existing zarr store at `store` is not a group.
        """
        store_path = Path(store)

        if not store_path.exists():
            self.current_store = zarr.create_group(store)
        else:
            open_store = zarr.open(store)

            if not isinstance(open_store, zarr.Group):
                raise RuntimeError(f"Zarr store at {store} is not a group.")

            self.current_store = open_store

        self._current_store_path = store_path

        # the previously cached group and array belong to another store
        self.current_group = None
        self._current_group_path = None
        self.current_array = None
        self._current_array_name = None

        print(f"Store: {store_path.absolute()}")

    def _create_group(self, group_path: str) -> None:
        """Create or open a group in the current zarr storage.

        The group becomes the current group. The cached array is reset, since
        it belongs to the previous group.

        Parameters
        ----------
        group_path : str
            Path to the group within the zarr store.

        Raises
        ------
        RuntimeError
            If the zarr store has not been initialized, or if the existing node
            at `group_path` is not a group.
        """
        if self.current_store is None:
            raise RuntimeError("Zarr store not initialized.")

        if group_path not in self.current_store:
            self.current_group = self.current_store.create_group(group_path)
        else:
            current_group = self.current_store[group_path]
            if not isinstance(current_group, zarr.Group):
                raise RuntimeError(f"Zarr group at {group_path} is not a group.")

            self.current_group = current_group

        self._current_group_path = group_path

        # the previously cached array belongs to another group
        self.current_array = None
        self._current_array_name = None

    def _create_array(
        self,
        array_name: str,
        axes: str,
        data_shape: Sequence[int],
        shards: tuple[int, ...] | None,
        chunks: tuple[int, ...] | None,
    ) -> None:
        """Create or open an array in the current zarr group.

        Parameters
        ----------
        array_name : str
            Name of the array within the zarr group.
        axes : str
            Axes string of the original data. It may contain `T`, which is
            multiplexed with `S` in `data_shape`.
        data_shape : Sequence[int]
            Shape of the array in SC(Z)YX order with potential singleton
            dimensions.
        shards : tuple[int, ...] or None
            Shard size for the array.
        chunks : tuple[int, ...] or None
            Chunk size for the array. If `None`, chunks are computed
            automatically.

        Raises
        ------
        RuntimeError
            If the zarr group has not been initialized, or if the existing node
            named `array_name` is not an array.
        ValueError
            If the lengths of the shape, chunks and shards do not match.
        """
        if self.current_group is None:
            raise RuntimeError("Zarr group not initialized.")

        if array_name not in self.current_group:
            # get shape without non-existing axes (S or C). T is relabeled as S
            # first, since it occupies the S position of `data_shape`.
            updated_shape = _update_data_shape(_update_T_axis(axes), data_shape)

            if chunks is not None and len(updated_shape) != len(chunks):
                raise ValueError(
                    f"Shape {updated_shape} and chunks {chunks} have different lengths."
                )

            if chunks is None:
                chunks = _auto_chunks(axes, data_shape)

            # TODO if we auto_chunks, we probably want to auto shards as well
            # there is shards="auto" in zarr, where array.target_shard_size_bytes
            # needs to be used (see zarr-python docs)
            if shards is not None and len(chunks) != len(shards):
                raise ValueError(
                    f"Chunks {chunks} and shards {shards} have different lengths."
                )

            self.current_array = self.current_group.create_array(
                name=array_name,
                shape=updated_shape,
                shards=shards,
                chunks=chunks,
                dtype=float32,
            )
        else:
            current_array = self.current_group[array_name]
            if not isinstance(current_array, zarr.Array):
                raise RuntimeError(f"Zarr array at {array_name} is not an array.")
            self.current_array = current_array

        self._current_array_name = array_name

    def write_tile(self, dirpath: Path, region: ImageRegionData) -> None:
        """Write cropped tile to zarr array.

        The output store, group and array are (re)opened only when the tile's
        source location differs from the one of the previous tile.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        region : ImageRegionData
            Image region data containing tile information.

        Raises
        ------
        NotImplementedError
            If `region.source` is not a valid zarr URI.
        TypeError
            If `region.region_spec` is not a tile specification.
        RuntimeError
            If the zarr array could not be initialized.
        """
        if not is_valid_uri(region.source):
            raise NotImplementedError(
                f"Invalid zarr URI: {region.source}. Currently, only predicting from "
                "Zarr files is supported when writing Zarr tiles."
            )

        if not is_tile_specs(region.region_spec):
            raise TypeError(
                "Expected tile specifications in `region.region_spec`, got "
                f"{region.region_spec}."
            )
        tile_spec: TileSpecs = region.region_spec

        store_path, parent_path, array_name = decipher_zarr_uri(region.source)
        output_store_path = _add_output_key(dirpath, store_path)

        # compare against the identifiers recorded when opening, so that a
        # change at any level re-opens that level (and resets the ones below)
        if (
            self.current_store is None
            or self._current_store_path != output_store_path
        ):
            self._create_zarr(output_store_path)

        if self.current_group is None or self._current_group_path != parent_path:
            self._create_group(parent_path)

        if self.current_array is None or self._current_array_name != array_name:
            # data_shape, chunks and shards are in SC(Z)YX order since they are reshaped
            # in the zarr image stack loader
            # If the source is not a Zarr file, then chunks and shards will be `None`.
            shape = region.data_shape
            chunks: tuple[int, ...] | None = region.additional_metadata.get(
                "chunks", None
            )
            shards: tuple[int, ...] | None = region.additional_metadata.get(
                "shards", None
            )
            self._create_array(array_name, region.axes, shape, shards, chunks)

        if self.current_array is None:
            raise RuntimeError("Zarr array not initialized.")

        crop_coords = tile_spec["crop_coords"]
        crop_size = tile_spec["crop_size"]
        stitch_coords = tile_spec["stitch_coords"]

        # index along the (possibly T-multiplexed) sample dimension
        sample_idx = tile_spec["sample_idx"]

        # TODO there is duplicated code in stitch_prediction
        crop_slices: tuple[builtins.ellipsis | slice | int, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(crop_coords, crop_size, strict=True)
            ],
        )
        stitch_slices: tuple[builtins.ellipsis | slice | int, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(stitch_coords, crop_size, strict=True)
            ],
        )

        # region.data has shape C(Z)YX, broadcast can fail with singleton dims
        crop = region.data[crop_slices]

        if region.data.shape[0] == 1 and "C" not in region.axes:
            # singleton C dim, need to remove it before writing
            # unless it was present in the original axes
            crop = crop[0]

        # T-only data is stored along the S dimension, so relabel before checking
        if "S" in _update_T_axis(region.axes):
            # the ellipsis expands to the channel dim (if any) of the output array
            stitch_slices = (sample_idx, *stitch_slices)

        self.current_array[stitch_slices] = crop

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Write all tiles to a Zarr file.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        for region in predictions:
            self.write_tile(dirpath, region)