careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""Module containing `PredictionWriterCallback` class."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
10
|
+
from pytorch_lightning.callbacks import BasePredictionWriter
|
|
11
|
+
|
|
12
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
13
|
+
from careamics.file_io.write.get_func import SupportedWriteType, WriteFunc
|
|
14
|
+
from careamics.lightning.dataset_ng.prediction import decollate_image_region_data
|
|
15
|
+
from careamics.utils import get_logger
|
|
16
|
+
|
|
17
|
+
from .write_strategy import WriteStrategy
|
|
18
|
+
from .write_strategy_factory import create_write_strategy
|
|
19
|
+
|
|
20
|
+
# Module-level logger used for warnings/info emitted by the callback.
logger = get_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PredictionWriterCallback(BasePredictionWriter):
    """
    PyTorch Lightning callback to save predictions.

    A `WriteStrategy` must be provided at instantiation or later via
    `set_writing_strategy`.

    Parameters
    ----------
    dirpath : Path or str, default=""
        The path to the directory where prediction outputs will be saved. If
        `dirpath` is not absolute it is assumed to be relative to current working
        directory. If left empty, no output directory is configured (e.g. when the
        destination is given by a zarr store path instead).
    write_strategy : WriteStrategy or None, default=None
        A strategy for writing predictions.

    Attributes
    ----------
    writing_predictions : bool
        If writing predictions is turned on or off.
    dirpath : pathlib.Path or None
        Absolute path to the directory where prediction outputs will be saved, or
        None if no directory was provided.
    write_strategy : WriteStrategy or None
        A strategy for writing predictions, None until one is provided.
    """

    def __init__(
        self,
        dirpath: Path | str = "",
        write_strategy: WriteStrategy | None = None,
    ):
        """
        Constructor.

        A `WriteStrategy` must be provided at instantiation or later via
        `set_writing_strategy`.

        Parameters
        ----------
        dirpath : pathlib.Path or str, default=""
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory. If left empty, no output directory is configured.
        write_strategy : WriteStrategy or None, default=None
            A strategy for writing predictions.
        """
        super().__init__(write_interval="batch")

        self.writing_predictions = True  # flag to turn off predictions

        # Both attributes are always assigned so that `setup` and
        # `write_on_batch_end` can compare them against None; a bare annotation
        # would leave them undefined and raise AttributeError instead.
        self.write_strategy: WriteStrategy | None = write_strategy
        self.dirpath: Path | None = None

        # if a dirpath is provided, initialize it
        # in some cases (e.g. zarr), destination is provided by the zarr store path
        if dirpath != "":
            self._init_dirpath(dirpath)

    def disable_writing(self, disable_writing: bool) -> None:
        """Disable writing.

        Parameters
        ----------
        disable_writing : bool
            If writing predictions should be disabled.
        """
        # `writing_predictions` is the positive flag, so it is the negation of
        # the requested "disable" state.
        self.writing_predictions = not disable_writing

    def _init_dirpath(self, dirpath: Path | str) -> None:
        """
        Initialize directory path. Should only be called from `__init__`.

        Relative paths are resolved against the current working directory and a
        warning is logged.

        Parameters
        ----------
        dirpath : pathlib.Path or str
            See `__init__` description.
        """
        dirpath = Path(dirpath)
        if not dirpath.is_absolute():
            dirpath = Path.cwd() / dirpath
            logger.warning(
                "Prediction output directory is not absolute, absolute path assumed to "
                f"be '{dirpath}'"
            )
        self.dirpath = dirpath

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """
        Create the prediction output directory when predict begins.

        Called when fit, validate, test, predict, or tune begins. The directory is
        only created for the 'predict' stage and only if a `dirpath` was given.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        stage : str
            Stage of training e.g. 'predict', 'fit', 'validate'.
        """
        super().setup(trainer, pl_module, stage)
        if stage == "predict" and self.dirpath is not None:
            # make prediction output directory
            logger.info("Making prediction output directory.")
            self.dirpath.mkdir(parents=True, exist_ok=True)

    def set_writing_strategy(
        self,
        write_type: SupportedWriteType,
        tiled: bool,
        write_func: WriteFunc | None = None,
        write_extension: str | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Set the writing strategy.

        Must be called before writing predictions, unless a strategy was passed
        to the constructor.

        Parameters
        ----------
        write_type : SupportedWriteType
            The type of writing to perform.
        tiled : bool
            Whether to write in tiled format.
        write_func : WriteFunc or None, default=None
            A custom writing function.
        write_extension : str or None, default=None
            The file extension to use when writing files.
        write_func_kwargs : dict of str to Any, default=None
            Additional keyword arguments to pass to `write_func`.
        """
        self.write_strategy = create_write_strategy(
            write_type=write_type,
            tiled=tiled,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: ImageRegionData,
        batch_indices: Sequence[int] | None,
        batch: ImageRegionData,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """
        Write predictions at the end of a batch.

        Writing method is determined by the attribute `write_strategy`.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        prediction : ImageRegionData
            Prediction outputs of `batch`.
        batch_indices : sequence of int, optional
            Batch indices.
        batch : ImageRegionData
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.

        Raises
        ------
        RuntimeError
            If no write strategy has been set.
        """
        # if writing prediction is turned off
        if not self.writing_predictions:
            return

        if self.write_strategy is None:
            raise RuntimeError(
                "No write strategy defined for `PredictionWriterCallback`, cannot write"
                " predictions. Call `set_writing_strategy` to pass a write strategy."
            )

        assert prediction is not None
        predictions = decollate_image_region_data(prediction)

        # NOTE(review): `self.dirpath` is None when no dirpath was configured;
        # presumably only strategies that do not need a directory (e.g. zarr)
        # are used in that case — confirm against the strategy implementations.
        self.write_strategy.write_batch(
            dirpath=self.dirpath,
            predictions=predictions,
        )
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""A strategy writing whole images directly."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
7
|
+
from careamics.file_io import WriteFunc
|
|
8
|
+
from careamics.lightning.dataset_ng.prediction import (
|
|
9
|
+
combine_samples,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from .file_path_utils import create_write_file_path
|
|
13
|
+
from .write_strategy import WriteStrategy
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# TODO bug: batch is over samples for whole images, if one batch does not cover
|
|
17
|
+
# all samples, it will write an incomplete image, then overwrite it with the next
|
|
18
|
+
# batch
|
|
19
|
+
class WriteImage(WriteStrategy):
    """
    Write strategy that saves whole (i.e. un-tiled) image predictions.

    Parameters
    ----------
    write_func : WriteFunc
        Callable used to write each image to disk.
    write_extension : str
        Extension appended to each output file path.
    write_func_kwargs : dict of {str: Any}
        Additional keyword arguments forwarded to `write_func`.

    Attributes
    ----------
    write_func : WriteFunc
        Callable used to write each image to disk.
    write_extension : str
        Extension appended to each output file path.
    write_func_kwargs : dict of {str: Any}
        Additional keyword arguments forwarded to `write_func`.
    """

    def __init__(
        self,
        write_func: WriteFunc,
        write_extension: str,
        write_func_kwargs: dict[str, Any],
    ) -> None:
        """
        Store the writer callable, output extension and writer kwargs.

        Parameters
        ----------
        write_func : WriteFunc
            Callable used to write each image to disk.
        write_extension : str
            Extension appended to each output file path.
        write_func_kwargs : dict of {str: Any}
            Additional keyword arguments forwarded to `write_func`.
        """
        super().__init__()

        self.write_func: WriteFunc = write_func
        self.write_extension: str = write_extension
        self.write_func_kwargs: dict[str, Any] = write_func_kwargs

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Combine decollated samples into full images and write each one.

        Each image is written to a path derived from its source and `dirpath`
        (via `create_write_file_path`) using the configured extension.

        Parameters
        ----------
        dirpath : Path
            Directory the predictions are written to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        assert predictions is not None

        images, image_sources = combine_samples(predictions)

        writer = self.write_func
        extension = self.write_extension
        extra_kwargs = self.write_func_kwargs

        for idx, img in enumerate(images):
            output_path = create_write_file_path(
                dirpath=dirpath,
                file_path=Path(image_sources[idx]),
                write_extension=extension,
            )
            writer(file_path=output_path, img=img, **extra_kwargs)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Module containing different strategies for writing predictions."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class WriteStrategy(Protocol):
    """Structural interface that every prediction write strategy implements."""

    def write_batch(self, dirpath: Path, predictions: list[ImageRegionData]) -> None:
        """
        Write one batch of decollated predictions.

        Any class providing a method with this signature satisfies the protocol.

        Parameters
        ----------
        dirpath : Path
            Directory the predictions are written to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""Module containing convenience function to create `WriteStrategy`."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from careamics.config.support import SupportedData
|
|
6
|
+
from careamics.file_io import SupportedWriteType, WriteFunc, get_write_func
|
|
7
|
+
|
|
8
|
+
from .cached_tiles_strategy import CachedTiles
|
|
9
|
+
from .write_image_strategy import WriteImage
|
|
10
|
+
from .write_strategy import WriteStrategy
|
|
11
|
+
from .write_tiles_zarr_strategy import WriteTilesZarr
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def create_write_strategy(
    write_type: SupportedWriteType,
    tiled: bool,
    write_func: WriteFunc | None = None,
    write_extension: str | None = None,
    write_func_kwargs: dict[str, Any] | None = None,
) -> WriteStrategy:
    """
    Build a write strategy from convenience parameters.

    Parameters
    ----------
    write_type : {"tiff", "zarr", "custom"}
        The data type to save as, includes custom.
    tiled : bool
        Whether the prediction will be tiled or not.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed. See notes below.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        `WriteImage` for un-tiled predictions, otherwise the strategy returned by
        `_create_tiled_write_strategy` (`CachedTiles` or `WriteTilesZarr`).

    Notes
    -----
    The `write_func` function signature must match that of the example below
    ```
    write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
    ```

    The `write_func_kwargs` will be passed to the `write_func` doing the following:
    ```
    write_func(file_path=file_path, img=img, **kwargs)
    ```
    """
    kwargs: dict[str, Any] = {} if write_func_kwargs is None else write_func_kwargs

    if tiled:
        # tiles are either cached until an image is complete, or written to zarr
        return _create_tiled_write_strategy(
            write_type=write_type,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=kwargs,
        )

    # keyword arguments are evaluated left to right, so the write function is
    # resolved before the extension, as in the tiled helper
    return WriteImage(
        write_func=select_write_func(write_type=write_type, write_func=write_func),
        write_extension=select_write_extension(
            write_type=write_type, write_extension=write_extension
        ),
        write_func_kwargs=kwargs,
    )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _create_tiled_write_strategy(
    write_type: SupportedWriteType,
    write_func: WriteFunc | None,
    write_extension: str | None,
    write_func_kwargs: dict[str, Any],
) -> WriteStrategy:
    """
    Create a tiled write strategy.

    Either `CachedTiles` for caching tiles until a whole image is predicted or
    `WriteTilesZarr` for writing tiles directly to disk.

    Parameters
    ----------
    write_type : {"tiff", "zarr", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed. See
        `select_write_func` for the required signature.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.
    write_func_kwargs : dict of {str: any}
        Additional keyword arguments to be passed to the save function.

    Returns
    -------
    WriteStrategy
        A strategy for writing tiled predictions.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` or `write_extension` has not
        been given (raised by `select_write_func` / `select_write_extension`).
    NotImplementedError
        Not raised by this function itself; documented for `write_type="zarr"`,
        presumably raised by the returned `WriteTilesZarr` strategy (TODO confirm).

    Notes
    -----
    For `write_type="zarr"`, `write_func`, `write_extension` and
    `write_func_kwargs` are ignored.
    """
    if write_type == "zarr":
        # Tiles go straight to disk; the custom write options are not used.
        return WriteTilesZarr()

    # Every other type caches tiles until a whole image has been predicted,
    # then saves it with the resolved write function and extension.
    write_func = select_write_func(write_type=write_type, write_func=write_func)
    write_extension = select_write_extension(
        write_type=write_type, write_extension=write_extension
    )
    return CachedTiles(
        write_func=write_func,
        write_extension=write_extension,
        write_func_kwargs=write_func_kwargs,
    )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def select_write_func(
    write_type: SupportedWriteType, write_func: WriteFunc | None = None
) -> WriteFunc:
    """
    Return a function to write images.

    If `write_type` is "custom" then `write_func` is returned, otherwise the known
    write function for `write_type` is selected.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed. See notes below.

    Returns
    -------
    WriteFunc
        A function for writing images.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_func` has not been given.

    Notes
    -----
    The `write_func` function signature must match that of the example below
        ```
        write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
        ```
    """
    if write_type == SupportedData.CUSTOM:
        # A custom type has no built-in writer, so the caller must supply one.
        if write_func is None:
            raise ValueError(
                "A save function must be provided for custom data types."
                # TODO: link to how save functions should be implemented
            )
        return write_func
    return get_write_func(write_type)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def select_write_extension(
    write_type: SupportedWriteType, write_extension: str | None = None
) -> str:
    """
    Return an extension to add to file paths.

    If `write_type` is "custom" then `write_extension` is returned, otherwise the
    known write extension for `write_type` is selected.

    Parameters
    ----------
    write_type : {"tiff", "custom"}
        The data type to save as, includes custom.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.

    Returns
    -------
    str
        The extension to be added to file paths.

    Raises
    ------
    ValueError
        If `write_type="custom"` but `write_extension` has not been given. Also
        raised by `SupportedData(write_type)` when `write_type` is not a valid
        `SupportedData` value (assuming standard `Enum` lookup behaviour).
    """
    write_type_: SupportedData = SupportedData(write_type)  # new variable for mypy
    if write_type_ == SupportedData.CUSTOM:
        # A custom type has no known extension, so the caller must supply one.
        if write_extension is None:
            raise ValueError("A save extension must be provided for custom data types.")
        return write_extension
    # kind of a weird pattern -> reason to move get_extension from SupportedData
    return write_type_.get_extension(write_type_)
|