careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Module containing convenience function to create `WriteStrategy`."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from careamics.config.support import SupportedData
|
|
6
|
+
from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
|
|
7
|
+
|
|
8
|
+
from .write_strategy import CacheTiles, WriteImage, WriteStrategy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: WriteFunc | None = None,
    write_extension: str | None = None,
    write_func_kwargs: dict[str, Any] | None = None,
) -> WriteStrategy:
    """
    Build the `WriteStrategy` matching the requested output settings.

    Untiled predictions get a `WriteImage` strategy; tiled predictions are
    delegated to `_create_tiled_write_strategy`.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        Ignored when a known `write_type` is selected. For a custom `write_type`
        a function to save the data must be passed. See notes below.
    write_extension : str, optional
        Ignored when a known `write_type` is selected. For a custom `write_type`
        an extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function. `None`
        is treated as an empty dictionary.

    Returns
    -------
    WriteStrategy
        A strategy for writing predictions.

    Notes
    -----
    The `write_func` function signature must match that of the example below
        ```
        write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
        ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
        ```
        write_func(file_path=file_path, img=img, **kwargs)
        ```
    """
    kwargs: dict[str, Any] = {} if write_func_kwargs is None else write_func_kwargs

    if tiled:
        # select CacheTiles or WriteTilesZarr (when implemented)
        return _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=kwargs,
        )

    # untiled: resolve the writer first, then the extension
    resolved_func = select_write_func(write_type=write_type, write_func=write_func)
    resolved_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return WriteImage(
        write_func=resolved_func,
        write_extension=resolved_extension,
        write_func_kwargs=kwargs,
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: WriteFunc | None,
    write_extension: str | None,
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Build the write strategy used for tiled predictions.

    Currently this is always `CacheTiles`, which caches tiles until a whole
    image is predicted. A `WriteTilesZarr` strategy for writing tiles directly
    to disk is planned but not implemented here.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored when a known `write_type` is selected. For a custom `write_type`
        a function to save the data must be passed.
    write_extension : str, optional
        Ignored when a known `write_type` is selected. For a custom `write_type`
        an extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing tiled predictions.

    Raises
    ------
    NotImplementedError
        If `write_type="zarr"` is chosen.
    """
    # TODO: once implemented, build and return `WriteTilesZarr` for zarr output
    if write_type == "zarr":
        raise NotImplementedError("Saving to zarr is not implemented yet.")

    resolved_func = select_write_func(write_type=write_type, write_func=write_func)
    resolved_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return CacheTiles(
        write_func=resolved_func,
        write_extension=resolved_extension,
        write_func_kwargs=write_func_kwargs,
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def select_write_func(
    write_type: SupportedWriteType, write_func: WriteFunc | None = None
) -> WriteFunc:
    """
    Resolve the function used to write images.

    For a known `write_type` the function is looked up with `get_write_func`;
    for the "custom" type the caller-supplied `write_func` is returned.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        Ignored when a known `write_type` is selected. For a custom `write_type`
        a function to save the data must be passed. See notes below.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
        ```
        write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
        ```
    """
    is_custom = write_type == SupportedData.CUSTOM
    if not is_custom:
        # known type: the user-supplied function (if any) is ignored
        return get_write_func(write_type)

    if write_func is None:
        raise ValueError(
            "A save function must be provided for custom data types."
            # TODO: link to how save functions should be implemented
        )
    return write_func
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def select_write_extension(
    write_type: SupportedWriteType, write_extension: str | None = None
) -> str:
    """
    Resolve the extension to add to output file paths.

    For a known `write_type` the extension is obtained from `SupportedData`;
    for the "custom" type the caller-supplied `write_extension` is returned.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        Ignored when a known `write_type` is selected. For a custom `write_type`
        an extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given.
    """
    data_type: SupportedData = SupportedData(write_type)  # enum-typed name for mypy

    is_custom = data_type == SupportedData.CUSTOM
    if not is_custom:
        # kind of a weird pattern -> reason to move get_extension from SupportedData
        return data_type.get_extension(data_type)

    if write_extension is None:
        raise ValueError("A save extension must be provided for custom data types.")
    return write_extension
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Progressbar callback."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
7
|
+
from pytorch_lightning.callbacks import TQDMProgressBar
|
|
8
|
+
from tqdm.auto import tqdm
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProgressBarCallback(TQDMProgressBar):
    """Progress bar for training, validation and testing steps.

    Each `init_*_tqdm` override builds a `tqdm` bar that writes to `sys.stdout`.
    """

    def init_train_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for training.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        return tqdm(
            desc="Training",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )

    def init_validation_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for validation.

        The bar is placed one row below the training bar when a training bar
        exists, and on the training bar's row otherwise.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        # The main progress bar doesn't exist in `trainer.validate()`.
        # NOTE(review): recent Lightning releases appear to raise `TypeError`
        # from the `train_progress_bar` property when the bar has not been
        # created, rather than returning `None` - confirm against the pinned
        # version. Both cases are treated as "no main bar".
        try:
            has_main_bar = self.train_progress_bar is not None
        except TypeError:
            has_main_bar = False

        return tqdm(
            desc="Validating",
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )

    def init_test_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for testing.

        Unlike the other bars, this one uses a fixed width of 100 columns.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        return tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=False,
            ncols=100,
            file=sys.stdout,
        )

    def get_metrics(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
        """Override this to customize the metrics displayed in the progress bar.

        Parameters
        ----------
        trainer : Trainer
            The trainer object.
        pl_module : LightningModule
            The LightningModule object, unused.

        Returns
        -------
        dict
            A shallow copy of `trainer.progress_bar_metrics`.
        """
        pbar_metrics = trainer.progress_bar_metrics
        return {**pbar_metrics}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Next-Generation DataModules for Careamics."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""NG Dataset compatible callbacks for PyTorch Lightning."""
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""A package for the `PredictionWriterCallback` class and utilities."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"CachedTiles",
|
|
5
|
+
"PredictionWriterCallback",
|
|
6
|
+
"WriteImage",
|
|
7
|
+
"WriteStrategy",
|
|
8
|
+
"WriteTilesZarr",
|
|
9
|
+
"create_write_file_path",
|
|
10
|
+
"create_write_strategy",
|
|
11
|
+
"decollate_image_region_data",
|
|
12
|
+
"select_write_extension",
|
|
13
|
+
"select_write_func",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
from .cached_tiles_strategy import CachedTiles
|
|
17
|
+
from .file_path_utils import create_write_file_path
|
|
18
|
+
from .prediction_writer_callback import (
|
|
19
|
+
PredictionWriterCallback,
|
|
20
|
+
decollate_image_region_data,
|
|
21
|
+
)
|
|
22
|
+
from .write_image_strategy import WriteImage
|
|
23
|
+
from .write_strategy import WriteStrategy
|
|
24
|
+
from .write_strategy_factory import (
|
|
25
|
+
create_write_strategy,
|
|
26
|
+
select_write_extension,
|
|
27
|
+
select_write_func,
|
|
28
|
+
)
|
|
29
|
+
from .write_tiles_zarr_strategy import WriteTilesZarr
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""A writing strategy that caches tiles until a whole image is predicted."""
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
8
|
+
from careamics.file_io import WriteFunc
|
|
9
|
+
from careamics.lightning.dataset_ng.prediction import (
|
|
10
|
+
stitch_single_prediction,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .file_path_utils import create_write_file_path
|
|
14
|
+
from .write_strategy import WriteStrategy
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CachedTiles(WriteStrategy):
    """
    A write strategy that will cache tiles.

    Tiles are cached until a whole image is predicted on. Then the stitched
    prediction is saved.

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    tile_cache : dict of {int: list of ImageRegionData}
        Cached tiles, grouped by the `"data_idx"` entry of their `region_spec`.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        A write strategy that will cache tiles.

        Tiles are cached until a whole image is predicted on. Then the stitched
        prediction is saved.

        Parameters
        ----------
        write_func : WriteFunc
            Function used to save predictions.
        write_extension : str
            Extension added to prediction file paths.
        write_func_kwargs : dict of {str: Any}
            Extra kwargs to pass to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

        # where tiles will be cached until a whole image has been predicted,
        # keyed by the data index of the image each tile belongs to
        self.tile_cache: dict[int, list[ImageRegionData]] = defaultdict(list)

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Cache the tiles of a batch and write every image that is now complete.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        assert predictions is not None

        # group incoming tiles by the image they belong to
        for tile in predictions:
            self.tile_cache[tile.region_spec["data_idx"]].append(tile)

        self._write_images(dirpath)

    def _get_full_images(self) -> list[int]:
        """
        Get data indices of full images contained in the cache.

        An image is full when the number of cached tiles equals the
        `"total_tiles"` entry in the `region_spec` of its first cached tile.

        Returns
        -------
        list of int
            Data indices of full images contained in the cache.

        Raises
        ------
        ValueError
            If more tiles are cached for an image than `"total_tiles"` announces.
        """
        full_images: list[int] = []
        for data_idx, tiles in self.tile_cache.items():
            exp_n_tiles = tiles[0].region_spec["total_tiles"]
            n_cached = len(tiles)

            if n_cached == exp_n_tiles:
                full_images.append(data_idx)
            elif n_cached > exp_n_tiles:
                raise ValueError(
                    f"More tiles cached for data_idx {data_idx} than expected. "
                    f"Expected {exp_n_tiles}, found "
                    f"{n_cached}."
                )

        return full_images

    def _stitch_and_write_single(
        self, dirpath: Path, tiles: list[ImageRegionData]
    ) -> None:
        """
        Stitch and write a single image from tiles.

        The output file name is derived from the `source` of the first tile.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        tiles : list[ImageRegionData]
            Tiles to stitch and write.
        """
        # stitch prediction
        prediction_image = stitch_single_prediction(tiles)

        # write prediction
        source: Path = Path(tiles[0].source)
        file_path = create_write_file_path(
            dirpath=dirpath,
            file_path=source,
            write_extension=self.write_extension,
        )
        self.write_func(
            file_path=file_path, img=prediction_image, **self.write_func_kwargs
        )

    def _write_images(self, dirpath: Path) -> None:
        """
        Write full images from cached tiles.

        Tiles of each written image are removed from the cache.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        """
        # `_get_full_images` returns a separate list, so popping from the cache
        # inside this loop does not mutate the collection being iterated
        for data_idx in self._get_full_images():
            tiles = self.tile_cache.pop(data_idx)
            self._stitch_and_write_single(dirpath, tiles)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Module containing file path utilities for `WriteStrategy` to use."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def create_write_file_path(
    dirpath: Path, file_path: Path, write_extension: str
) -> Path:
    """
    Build the path of the output file for a given input file.

    Keeps the stem of `file_path`, places it in `dirpath` and applies
    `write_extension` as its suffix.

    Parameters
    ----------
    dirpath : pathlib.Path
        The output directory to write file to.
    file_path : pathlib.Path
        The original file path.
    write_extension : str
        The extension that output files should have.

    Returns
    -------
    Path
        The output file path.
    """
    # `Path(...)` guards against str input
    source_stem = Path(file_path).stem

    # NOTE: `with_suffix` replaces any suffix still present in the stem, so a
    # dotted stem such as "img.ome" becomes "img" + `write_extension`.
    return dirpath / Path(source_stem).with_suffix(write_extension)
|