careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
"""Tile Zarr writing strategy."""
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import zarr
|
|
8
|
+
from numpy import float32
|
|
9
|
+
|
|
10
|
+
from careamics.dataset.dataset_utils.dataset_utils import get_axes_order
|
|
11
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
12
|
+
from careamics.dataset_ng.image_stack_loader.zarr_utils import (
|
|
13
|
+
decipher_zarr_uri,
|
|
14
|
+
is_valid_uri,
|
|
15
|
+
)
|
|
16
|
+
from careamics.dataset_ng.patching_strategies import TileSpecs, is_tile_specs
|
|
17
|
+
|
|
18
|
+
OUTPUT_KEY = "_output"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _update_data_shape(axes: str, data_shape: Sequence[int]) -> tuple[int, ...]:
|
|
22
|
+
"""Update data shape to remove non existing dimensions.
|
|
23
|
+
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
axes : str
|
|
27
|
+
Axes string of the original data.
|
|
28
|
+
data_shape : Sequence[int]
|
|
29
|
+
Shape of the array in SC(Z)YX order with potential singleton dimensions.
|
|
30
|
+
|
|
31
|
+
Returns
|
|
32
|
+
-------
|
|
33
|
+
tuple[int, ...]
|
|
34
|
+
Updated shape with non-existing axes removed.
|
|
35
|
+
"""
|
|
36
|
+
new_shape = []
|
|
37
|
+
|
|
38
|
+
if "S" in axes:
|
|
39
|
+
new_shape.append(data_shape[0])
|
|
40
|
+
|
|
41
|
+
if "C" in axes:
|
|
42
|
+
new_shape.append(data_shape[1])
|
|
43
|
+
|
|
44
|
+
for idx in range(2, len(data_shape)):
|
|
45
|
+
new_shape.append(data_shape[idx])
|
|
46
|
+
|
|
47
|
+
return tuple(new_shape)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _update_T_axis(axes: str) -> str:
|
|
51
|
+
"""Update axes string to account for multiplexed S and T dimensions.
|
|
52
|
+
|
|
53
|
+
If only `T` is present, then it is relabeled as `S`. If both `S` and `T` are
|
|
54
|
+
present, then `T` is removed.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
axes : str
|
|
59
|
+
Axes string of the original data.
|
|
60
|
+
|
|
61
|
+
Returns
|
|
62
|
+
-------
|
|
63
|
+
str
|
|
64
|
+
Updated axes string.
|
|
65
|
+
"""
|
|
66
|
+
if "T" in axes:
|
|
67
|
+
if "S" in axes:
|
|
68
|
+
# remove T
|
|
69
|
+
axes = axes.replace("T", "")
|
|
70
|
+
else:
|
|
71
|
+
# relabel T as S
|
|
72
|
+
axes = axes.replace("T", "S")
|
|
73
|
+
return axes
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _auto_chunks(axes: str, data_shape: Sequence[int]) -> tuple[int, ...]:
    """Derive default chunk sizes from the axes and the data shape.

    Each spatial dimension (Z, Y, X) gets a chunk size equal to its length,
    capped at 128. Every other dimension gets a chunk size of 1.

    Parameters
    ----------
    axes : str
        Axes string of the original data.
    data_shape : Sequence[int]
        Shape of the array in SC(Z)YX order with potential singleton dimensions.

    Returns
    -------
    tuple[int, ...]
        Chunk sizes for each dimension in SC(Z)YX order, but excluding dimensions that
        are not in the axes string.
    """
    # T is multiplexed with S, so it is folded into S before anything else
    normalized_axes = _update_T_axis(axes)

    # positions within `normalized_axes`, listed in SC(Z)YX order
    ordered_positions = get_axes_order(normalized_axes, ref_axes="SCZYX")

    # `data_shape` always has S and C dims; count the singleton ones that were
    # added, since they shift the spatial dims relative to the axes present
    n_added_dims = sum(1 for name in ("S", "C") if name not in normalized_axes)

    # TODO we should probably not chunk along Z (#658)
    # TODO the cap of 128 is an arbitrary value
    return tuple(
        min(128, data_shape[rank + n_added_dims])
        if normalized_axes[position] in ("Z", "Y", "X")
        else 1
        for rank, position in enumerate(ordered_positions)
    )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _add_output_key(dirpath: Path, path: str | Path) -> Path:
    """Build the output zarr path from the name of the original zarr.

    The stem of `path` (its last component without the final suffix) is
    suffixed with `OUTPUT_KEY` and `.zarr`, and the result is placed in `dirpath`.

    Parameters
    ----------
    dirpath : Path
        Directory path to save the output zarr.
    path : str | Path
        Original zarr path.

    Returns
    -------
    Path
        Path in `dirpath` of the zarr named after `path` with `OUTPUT_KEY` added.
    """
    source_stem = Path(path).stem
    return dirpath / f"{source_stem}{OUTPUT_KEY}.zarr"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class WriteTilesZarr:
    """Zarr tile writer strategy.

    This writer creates zarr files, groups and arrays as needed and writes tiles
    into the appropriate locations.

    The store, group and array most recently written to are kept open, together
    with the path or name they were opened with. A tile only triggers a new
    store, group or array when its target differs from the cached one, and
    changing a parent (store or group) discards the cached children.
    """

    def __init__(self) -> None:
        """Constructor."""
        self.current_store: zarr.Group | None = None
        self.current_group: zarr.Group | None = None
        self.current_array: zarr.Array | None = None

        # identity of the cached handles above, used by `write_tile` to detect
        # that a tile targets a different store, group or array
        self._store_path: Path | None = None
        self._group_path: str | None = None
        self._array_name: str | None = None

    def _create_zarr(self, store: str | Path) -> None:
        """Create a new zarr storage, or open it if it already exists.

        The cached group and array are discarded, since they belong to the
        previously opened store.

        Parameters
        ----------
        store : str | Path
            Path to the zarr store.

        Raises
        ------
        RuntimeError
            If the existing zarr store is not a group.
        """
        if not Path(store).exists():
            self.current_store = zarr.create_group(store)
        else:
            open_store = zarr.open(store)

            if not isinstance(open_store, zarr.Group):
                raise RuntimeError(f"Zarr store at {store} is not a group.")

            self.current_store = open_store

        self._store_path = Path(store)

        # handles into the previous store must never be reused
        self.current_group = None
        self._group_path = None
        self.current_array = None
        self._array_name = None

        print(f"Store: {Path(store).absolute()}")

    def _create_group(self, group_path: str) -> None:
        """Create a new group in an existing zarr storage, or open it if it exists.

        The cached array is discarded, since it belongs to the previously opened
        group.

        Parameters
        ----------
        group_path : str
            Path to the group within the zarr store.

        Raises
        ------
        RuntimeError
            If the zarr store has not been initialized, or if `group_path` exists
            in the store but is not a group.
        """
        if self.current_store is None:
            raise RuntimeError("Zarr store not initialized.")

        if group_path not in self.current_store:
            self.current_group = self.current_store.create_group(group_path)
        else:
            current_group = self.current_store[group_path]
            if not isinstance(current_group, zarr.Group):
                raise RuntimeError(f"Zarr group at {group_path} is not a group.")

            self.current_group = current_group

        self._group_path = group_path

        # an array handle from the previous group must never be reused
        self.current_array = None
        self._array_name = None

    def _create_array(
        self,
        array_name: str,
        axes: str,
        data_shape: Sequence[int],
        shards: tuple[int, ...] | None,
        chunks: tuple[int, ...] | None,
    ) -> None:
        """Create a new array in an existing zarr group, or open it if it exists.

        Parameters
        ----------
        array_name : str
            Name of the array within the zarr group.
        axes : str
            Axes string in SC(Z)YX format with original data order.
        data_shape : Sequence[int]
            Shape of the array.
        shards : tuple[int, ...] or None
            Shard size for the array.
        chunks : tuple[int, ...] or None
            Chunk size for the array. If `None`, chunk sizes are generated
            automatically from `axes` and `data_shape`.

        Raises
        ------
        RuntimeError
            If the zarr group has not been initialized, or if `array_name` exists
            in the group but is not an array.
        ValueError
            If the shape, chunks and shards do not all have the same length.
        """
        if self.current_group is None:
            raise RuntimeError("Zarr group not initialized.")

        if array_name not in self.current_group:
            # get shape without non-existing axes (S or C)
            updated_shape = _update_data_shape(axes, data_shape)

            if chunks is not None and len(updated_shape) != len(chunks):
                raise ValueError(
                    f"Shape {updated_shape} and chunks {chunks} have different lengths."
                )

            if chunks is None:
                chunks = _auto_chunks(axes, data_shape)

            # TODO if we auto_chunks, we probably want to auto shards as well
            # there is shards="auto" in zarr, where array.target_shard_size_bytes
            # needs to be used (see zarr-python docs)
            if shards is not None and len(chunks) != len(shards):
                raise ValueError(
                    f"Chunks {chunks} and shards {shards} have different lengths."
                )

            self.current_array = self.current_group.create_array(
                name=array_name,
                shape=updated_shape,
                shards=shards,
                chunks=chunks,
                dtype=float32,
            )
        else:
            current_array = self.current_group[array_name]
            if not isinstance(current_array, zarr.Array):
                raise RuntimeError(f"Zarr array at {array_name} is not an array.")
            self.current_array = current_array

        self._array_name = array_name

    @staticmethod
    def _spatial_slices(
        starts: Sequence[int], lengths: Sequence[int]
    ) -> tuple[slice, ...]:
        """Build one slice per spatial dimension from start coordinates and sizes.

        Parameters
        ----------
        starts : Sequence[int]
            Start coordinate along each spatial dimension.
        lengths : Sequence[int]
            Extent along each spatial dimension.

        Returns
        -------
        tuple[slice, ...]
            Slices `start:start + length`, one per spatial dimension.

        Raises
        ------
        ValueError
            If `starts` and `lengths` have different lengths.
        """
        return tuple(
            slice(start, start + length)
            for start, length in zip(starts, lengths, strict=True)
        )

    def write_tile(self, dirpath: Path, region: ImageRegionData) -> None:
        """Write cropped tile to zarr array.

        The output store is named after the source store (see `_add_output_key`)
        and mirrors its group and array layout. The store, group and array are
        only (re)opened when the tile targets a different one than the previous
        tile.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        region : ImageRegionData
            Image region data containing tile information.

        Raises
        ------
        NotImplementedError
            If the source of the region is not a valid zarr URI.
        RuntimeError
            If the zarr array could not be initialized.
        """
        if not is_valid_uri(region.source):
            raise NotImplementedError(
                f"Invalid zarr URI: {region.source}. Currently, only predicting from "
                f"Zarr files is supported when writing Zarr tiles."
            )

        store_path, parent_path, array_name = decipher_zarr_uri(region.source)
        output_store_path = _add_output_key(dirpath, store_path)

        # each `_create_*` call discards the cached children, so a change of
        # store or group always forces the lower levels to be reopened as well
        if self.current_store is None or self._store_path != output_store_path:
            self._create_zarr(output_store_path)

        if self.current_group is None or self._group_path != parent_path:
            self._create_group(parent_path)

        if self.current_array is None or self._array_name != array_name:
            # data_shape, chunks and shards are in SC(Z)YX order since they are reshaped
            # in the zarr image stack loader
            # If the source is not a Zarr file, then chunks and shards will be `None`.
            chunks: tuple[int, ...] | None = region.additional_metadata.get(
                "chunks", None
            )
            shards: tuple[int, ...] | None = region.additional_metadata.get(
                "shards", None
            )
            self._create_array(
                array_name, region.axes, region.data_shape, shards, chunks
            )

        if self.current_array is None:
            raise RuntimeError("Zarr array not initialized.")

        assert is_tile_specs(region.region_spec)  # for mypy
        tile_spec: TileSpecs = region.region_spec
        crop_size = tile_spec["crop_size"]
        sample_idx = tile_spec["sample_idx"]

        # TODO there is duplicated code in stitch_prediction
        crop_spatial = self._spatial_slices(tile_spec["crop_coords"], crop_size)
        stitch_spatial = self._spatial_slices(tile_spec["stitch_coords"], crop_size)

        # region.data has shape C(Z)YX, broadcast can fail with singleton dims
        crop = region.data[(..., *crop_spatial)]

        has_channel = "C" in region.axes
        if region.data.shape[0] == 1 and not has_channel:
            # singleton C dim, need to remove it before writing
            # unless it was present in the original axes
            crop = crop[0]

        # T is multiplexed with S, so T-only data is also indexed by sample
        has_sample = "S" in region.axes or "T" in region.axes

        stitch_slices: tuple[builtins.ellipsis | slice | int, ...]
        if not has_sample:
            stitch_slices = (..., *stitch_spatial)
        elif has_channel:
            # the ellipsis stands for the channel dimension
            stitch_slices = (sample_idx, ..., *stitch_spatial)
        else:
            stitch_slices = (sample_idx, *stitch_spatial)

        self.current_array[stitch_slices] = crop

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Write all tiles to a Zarr file.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        for region in predictions:
            self.write_tile(dirpath, region)
|