careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Module containing `PredictionWriterCallback` class."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Union
|
|
8
|
+
|
|
9
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
10
|
+
from pytorch_lightning.callbacks import BasePredictionWriter
|
|
11
|
+
from torch.utils.data import DataLoader
|
|
12
|
+
|
|
13
|
+
from careamics.dataset import (
|
|
14
|
+
IterablePredDataset,
|
|
15
|
+
IterableTiledPredDataset,
|
|
16
|
+
)
|
|
17
|
+
from careamics.file_io import SupportedWriteType, WriteFunc
|
|
18
|
+
from careamics.utils import get_logger
|
|
19
|
+
|
|
20
|
+
from .write_strategy import WriteStrategy
|
|
21
|
+
from .write_strategy_factory import create_write_strategy
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
ValidPredDatasets = Union[IterablePredDataset, IterableTiledPredDataset]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class PredictionWriterCallback(BasePredictionWriter):
    """
    A PyTorch Lightning callback to save predictions.

    The callback itself only decides *whether* and *where* to write; the actual
    writing of each batch is delegated to the `write_strategy` object.

    Parameters
    ----------
    write_strategy : WriteStrategy
        A strategy for writing predictions.
    dirpath : Path or str, default="predictions"
        The path to the directory where prediction outputs will be saved. If
        `dirpath` is not absolute it is assumed to be relative to current working
        directory.

    Attributes
    ----------
    write_strategy : WriteStrategy
        A strategy for writing predictions.
    dirpath : pathlib.Path, default="predictions"
        The path to the directory where prediction outputs will be saved. If
        `dirpath` is not absolute it is assumed to be relative to current working
        directory.
    writing_predictions : bool
        If writing predictions is turned on or off.
    """

    def __init__(
        self,
        write_strategy: WriteStrategy,
        dirpath: Union[Path, str] = "predictions",
    ):
        """
        A PyTorch Lightning callback to save predictions.

        Parameters
        ----------
        write_strategy : WriteStrategy
            A strategy for writing predictions.
        dirpath : pathlib.Path or str, default="predictions"
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory.
        """
        # the base class is told to call `write_on_batch_end` after every batch
        super().__init__(write_interval="batch")

        # Toggle for CAREamist to switch off saving if desired
        self.writing_predictions: bool = True

        self.write_strategy: WriteStrategy = write_strategy

        # forward declaration
        self.dirpath: Path
        # attribute initialisation
        self._init_dirpath(dirpath)

    @classmethod
    def from_write_func_params(
        cls,
        write_type: SupportedWriteType,
        tiled: bool,
        write_func: WriteFunc | None = None,
        write_extension: str | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
        dirpath: Union[Path, str] = "predictions",
    ) -> PredictionWriterCallback:  # TODO: change type hint to self (find out how)
        """
        Initialize a `PredictionWriterCallback` from write function parameters.

        This will automatically create a `WriteStrategy` to be passed to the
        initialization of `PredictionWriterCallback`.

        Parameters
        ----------
        write_type : {"tiff", "custom"}
            The data type to save as, includes custom.
        tiled : bool
            Whether the prediction will be tiled or not.
        write_func : WriteFunc, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` a function to save the data must be passed. See notes below.
        write_extension : str, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` an extension to save the data with must be passed.
        write_func_kwargs : dict of {{str: any}}, optional
            Additional keyword arguments to be passed to the save function.
        dirpath : pathlib.Path or str, default="predictions"
            The path to the directory where prediction outputs will be saved. If
            `dirpath` is not absolute it is assumed to be relative to current working
            directory.

        Returns
        -------
        PredictionWriterCallback
            Callback for writing predictions.
        """
        write_strategy = create_write_strategy(
            write_type=write_type,
            tiled=tiled,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )
        return cls(write_strategy=write_strategy, dirpath=dirpath)

    def _init_dirpath(self, dirpath: Union[Path, str]) -> None:
        """
        Initialize directory path. Should only be called from `__init__`.

        A relative path is resolved against the current working directory and a
        warning reporting the resulting absolute path is logged.

        Parameters
        ----------
        dirpath : pathlib.Path or str
            See `__init__` description.
        """
        dirpath = Path(dirpath)
        if not dirpath.is_absolute():
            dirpath = Path.cwd() / dirpath
            # NOTE: adjacent literals are concatenated verbatim, so the first
            # one must end with a space ("assumed to be", not "assumed tobe").
            logger.warning(
                "Prediction output directory is not absolute, absolute path assumed to "
                f"be '{dirpath}'"
            )
        self.dirpath = dirpath

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """
        Create the prediction output directory when predict begins.

        Called when fit, validate, test, predict, or tune begins.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        stage : str
            Stage of training e.g. 'predict', 'fit', 'validate'.
        """
        super().setup(trainer, pl_module, stage)
        if stage == "predict":
            # make prediction output directory
            logger.info("Making prediction output directory.")
            self.dirpath.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,  # TODO: change to expected type
        batch_indices: Sequence[int] | None,
        batch: Any,  # TODO: change to expected type
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """
        Write predictions at the end of a batch.

        The method of prediction is determined by the attribute `write_strategy`.
        Nothing is written when `writing_predictions` is False.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        prediction : Any
            Prediction outputs of `batch`.
        batch_indices : sequence of int, optional
            Batch indices.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.

        Raises
        ------
        TypeError
            If the prediction dataset is neither an `IterablePredDataset` nor an
            `IterableTiledPredDataset`.
        """
        # if writing prediction is turned off
        if not self.writing_predictions:
            return

        # `predict_dataloaders` may be a single dataloader or a list of them
        dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
        dataloader: DataLoader = (
            dataloaders[dataloader_idx]
            if isinstance(dataloaders, list)
            else dataloaders
        )
        dataset: ValidPredDatasets = dataloader.dataset
        if not isinstance(dataset, (IterablePredDataset, IterableTiledPredDataset)):
            # Note: Error will be raised before here from the source type
            # This is for extra redundancy of errors.
            raise TypeError(
                "Prediction dataset has to be `IterableTiledPredDataset` or "
                "`IterablePredDataset`. Cannot be `InMemoryPredDataset` because "
                "filenames are taken from the original file."
            )

        self.write_strategy.write_batch(
            trainer=trainer,
            pl_module=pl_module,
            prediction=prediction,
            batch_indices=batch_indices,
            batch=batch,
            batch_idx=batch_idx,
            dataloader_idx=dataloader_idx,
            dirpath=self.dirpath,
        )
|
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""Module containing different strategies for writing predictions."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Protocol, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
10
|
+
from torch.utils.data import DataLoader
|
|
11
|
+
|
|
12
|
+
from careamics.config.data.tile_information import TileInformation
|
|
13
|
+
from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
|
|
14
|
+
from careamics.file_io import WriteFunc
|
|
15
|
+
from careamics.prediction_utils import stitch_prediction_single
|
|
16
|
+
|
|
17
|
+
from .file_path_utils import create_write_file_path, get_sample_file_path
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class WriteStrategy(Protocol):
    """Structural interface that every prediction write strategy must satisfy."""

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,  # TODO: change to expected type
        batch_indices: Sequence[int] | None,
        batch: Any,  # TODO: change to expected type
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Write the predictions of a single batch.

        Implementations of `WriteStrategy` have to provide this method.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : Any
            Predictions on `batch`.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.
        """
        ...
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class CacheTiles(WriteStrategy):
|
|
59
|
+
"""
|
|
60
|
+
A write strategy that will cache tiles.
|
|
61
|
+
|
|
62
|
+
Tiles are cached until a whole image is predicted on. Then the stitched
|
|
63
|
+
prediction is saved.
|
|
64
|
+
|
|
65
|
+
Parameters
|
|
66
|
+
----------
|
|
67
|
+
write_func : WriteFunc
|
|
68
|
+
Function used to save predictions.
|
|
69
|
+
write_extension : str
|
|
70
|
+
Extension added to prediction file paths.
|
|
71
|
+
write_func_kwargs : dict of {str: Any}
|
|
72
|
+
Extra kwargs to pass to `write_func`.
|
|
73
|
+
|
|
74
|
+
Attributes
|
|
75
|
+
----------
|
|
76
|
+
write_func : WriteFunc
|
|
77
|
+
Function used to save predictions.
|
|
78
|
+
write_extension : str
|
|
79
|
+
Extension added to prediction file paths.
|
|
80
|
+
write_func_kwargs : dict of {str: Any}
|
|
81
|
+
Extra kwargs to pass to `write_func`.
|
|
82
|
+
tile_cache : list of numpy.ndarray
|
|
83
|
+
Tiles cached for stitching prediction.
|
|
84
|
+
tile_info_cache : list of TileInformation
|
|
85
|
+
Cached tile information for stitching prediction.
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
def __init__(
|
|
89
|
+
self,
|
|
90
|
+
write_func: WriteFunc,
|
|
91
|
+
write_extension: str,
|
|
92
|
+
write_func_kwargs: dict[str, Any],
|
|
93
|
+
) -> None:
|
|
94
|
+
"""
|
|
95
|
+
A write strategy that will cache tiles.
|
|
96
|
+
|
|
97
|
+
Tiles are cached until a whole image is predicted on. Then the stitched
|
|
98
|
+
prediction is saved.
|
|
99
|
+
|
|
100
|
+
Parameters
|
|
101
|
+
----------
|
|
102
|
+
write_func : WriteFunc
|
|
103
|
+
Function used to save predictions.
|
|
104
|
+
write_extension : str
|
|
105
|
+
Extension added to prediction file paths.
|
|
106
|
+
write_func_kwargs : dict of {str: Any}
|
|
107
|
+
Extra kwargs to pass to `write_func`.
|
|
108
|
+
"""
|
|
109
|
+
super().__init__()
|
|
110
|
+
|
|
111
|
+
self.write_func: WriteFunc = write_func
|
|
112
|
+
self.write_extension: str = write_extension
|
|
113
|
+
self.write_func_kwargs: dict[str, Any] = write_func_kwargs
|
|
114
|
+
|
|
115
|
+
# where tiles will be cached until a whole image has been predicted
|
|
116
|
+
self.tile_cache: list[NDArray] = []
|
|
117
|
+
self.tile_info_cache: list[TileInformation] = []
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def last_tiles(self) -> list[bool]:
|
|
121
|
+
"""
|
|
122
|
+
List of bool to determine whether each tile in the cache is the last tile.
|
|
123
|
+
|
|
124
|
+
Returns
|
|
125
|
+
-------
|
|
126
|
+
list of bool
|
|
127
|
+
Whether each tile in the tile cache is the last tile.
|
|
128
|
+
"""
|
|
129
|
+
return [tile_info.last_tile for tile_info in self.tile_info_cache]
|
|
130
|
+
|
|
131
|
+
def write_batch(
|
|
132
|
+
self,
|
|
133
|
+
trainer: Trainer,
|
|
134
|
+
pl_module: LightningModule,
|
|
135
|
+
prediction: tuple[NDArray, list[TileInformation]],
|
|
136
|
+
batch_indices: Sequence[int] | None,
|
|
137
|
+
batch: tuple[NDArray, list[TileInformation]],
|
|
138
|
+
batch_idx: int,
|
|
139
|
+
dataloader_idx: int,
|
|
140
|
+
dirpath: Path,
|
|
141
|
+
) -> None:
|
|
142
|
+
"""
|
|
143
|
+
Cache tiles until the last tile is predicted; save the stitched prediction.
|
|
144
|
+
|
|
145
|
+
Parameters
|
|
146
|
+
----------
|
|
147
|
+
trainer : Trainer
|
|
148
|
+
PyTorch Lightning Trainer.
|
|
149
|
+
pl_module : LightningModule
|
|
150
|
+
PyTorch Lightning LightningModule.
|
|
151
|
+
prediction : Any
|
|
152
|
+
Predictions on `batch`.
|
|
153
|
+
batch_indices : sequence of int
|
|
154
|
+
Indices identifying the samples in the batch.
|
|
155
|
+
batch : Any
|
|
156
|
+
Input batch.
|
|
157
|
+
batch_idx : int
|
|
158
|
+
Batch index.
|
|
159
|
+
dataloader_idx : int
|
|
160
|
+
Dataloader index.
|
|
161
|
+
dirpath : Path
|
|
162
|
+
Path to directory to save predictions to.
|
|
163
|
+
"""
|
|
164
|
+
dataloaders: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
|
|
165
|
+
dataloader: DataLoader = (
|
|
166
|
+
dataloaders[dataloader_idx]
|
|
167
|
+
if isinstance(dataloaders, list)
|
|
168
|
+
else dataloaders
|
|
169
|
+
)
|
|
170
|
+
dataset: IterableTiledPredDataset = dataloader.dataset
|
|
171
|
+
if not isinstance(dataset, IterableTiledPredDataset):
|
|
172
|
+
raise TypeError("Prediction dataset is not `IterableTiledPredDataset`.")
|
|
173
|
+
|
|
174
|
+
# cache tiles (batches are split into single samples)
|
|
175
|
+
self.tile_cache.extend(np.split(prediction[0], prediction[0].shape[0]))
|
|
176
|
+
self.tile_info_cache.extend(prediction[1])
|
|
177
|
+
|
|
178
|
+
# save stitched prediction
|
|
179
|
+
if self._has_last_tile():
|
|
180
|
+
|
|
181
|
+
# get image tiles and remove them from the cache
|
|
182
|
+
tiles, tile_infos = self._get_image_tiles()
|
|
183
|
+
self._clear_cache()
|
|
184
|
+
|
|
185
|
+
# stitch prediction
|
|
186
|
+
prediction_image = stitch_prediction_single(
|
|
187
|
+
tiles=tiles, tile_infos=tile_infos
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
# write prediction
|
|
191
|
+
sample_id = tile_infos[0].sample_id # need this to select correct file name
|
|
192
|
+
input_file_path = get_sample_file_path(dataset=dataset, sample_id=sample_id)
|
|
193
|
+
file_path = create_write_file_path(
|
|
194
|
+
dirpath=dirpath,
|
|
195
|
+
file_path=input_file_path,
|
|
196
|
+
write_extension=self.write_extension,
|
|
197
|
+
)
|
|
198
|
+
self.write_func(
|
|
199
|
+
file_path=file_path, img=prediction_image[0], **self.write_func_kwargs
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
def _has_last_tile(self) -> bool:
|
|
203
|
+
"""
|
|
204
|
+
Whether a last tile is contained in the cached tiles.
|
|
205
|
+
|
|
206
|
+
Returns
|
|
207
|
+
-------
|
|
208
|
+
bool
|
|
209
|
+
Whether a last tile is contained in the cached tiles.
|
|
210
|
+
"""
|
|
211
|
+
return any(self.last_tiles)
|
|
212
|
+
|
|
213
|
+
def _clear_cache(self) -> None:
|
|
214
|
+
"""Remove the tiles in the cache up to the first last tile."""
|
|
215
|
+
index = self._last_tile_index()
|
|
216
|
+
self.tile_cache = self.tile_cache[index + 1 :]
|
|
217
|
+
self.tile_info_cache = self.tile_info_cache[index + 1 :]
|
|
218
|
+
|
|
219
|
+
def _last_tile_index(self) -> int:
|
|
220
|
+
"""
|
|
221
|
+
Find the index of the last tile in the tile cache.
|
|
222
|
+
|
|
223
|
+
Returns
|
|
224
|
+
-------
|
|
225
|
+
int
|
|
226
|
+
Index of last tile.
|
|
227
|
+
|
|
228
|
+
Raises
|
|
229
|
+
------
|
|
230
|
+
ValueError
|
|
231
|
+
If there is no last tile in the tile cache.
|
|
232
|
+
"""
|
|
233
|
+
last_tiles = self.last_tiles
|
|
234
|
+
if not any(last_tiles):
|
|
235
|
+
raise ValueError("No last tile in the tile cache.")
|
|
236
|
+
index = np.where(last_tiles)[0][0]
|
|
237
|
+
return index
|
|
238
|
+
|
|
239
|
+
def _get_image_tiles(self) -> tuple[list[NDArray], list[TileInformation]]:
    """
    Collect the cached tiles that belong to a single image.

    The selection runs from the start of the cache up to and including the
    index returned by `_last_tile_index`. The cache attributes themselves are
    not modified, only sliced.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        Tiles and tile information to stitch together a full image.
    """
    stop = self._last_tile_index() + 1
    return self.tile_cache[:stop], self.tile_info_cache[:stop]
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class WriteTilesZarr(WriteStrategy):
    """Write strategy meant to store tiles in a Zarr file (not implemented)."""

    def write_batch(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        prediction: Any,
        batch_indices: Sequence[int] | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        """
        Write the tiles of a batch to a Zarr file.

        There is no implementation behind this method, calling it only raises.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning Trainer.
        pl_module : LightningModule
            PyTorch Lightning LightningModule.
        prediction : Any
            Predictions on `batch`.
        batch_indices : sequence of int
            Indices identifying the samples in the batch.
        batch : Any
            Input batch.
        batch_idx : int
            Batch index.
        dataloader_idx : int
            Dataloader index.
        dirpath : Path
            Path to directory to save predictions to.

        Raises
        ------
        NotImplementedError
            Always, none of the arguments are used.
        """
        raise NotImplementedError()
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class WriteImage(WriteStrategy):
|
|
298
|
+
"""
|
|
299
|
+
A strategy for writing image predictions (i.e. un-tiled predictions).
|
|
300
|
+
|
|
301
|
+
Parameters
|
|
302
|
+
----------
|
|
303
|
+
write_func : WriteFunc
|
|
304
|
+
Function used to save predictions.
|
|
305
|
+
write_extension : str
|
|
306
|
+
Extension added to prediction file paths.
|
|
307
|
+
write_func_kwargs : dict of {str: Any}
|
|
308
|
+
Extra kwargs to pass to `write_func`.
|
|
309
|
+
|
|
310
|
+
Attributes
|
|
311
|
+
----------
|
|
312
|
+
write_func : WriteFunc
|
|
313
|
+
Function used to save predictions.
|
|
314
|
+
write_extension : str
|
|
315
|
+
Extension added to prediction file paths.
|
|
316
|
+
write_func_kwargs : dict of {str: Any}
|
|
317
|
+
Extra kwargs to pass to `write_func`.
|
|
318
|
+
"""
|
|
319
|
+
|
|
320
|
+
def __init__(
    self,
    write_func: WriteFunc,
    write_extension: str,
    write_func_kwargs: dict[str, Any],
) -> None:
    """
    Set up a strategy that writes whole (un-tiled) image predictions.

    Parameters
    ----------
    write_func : WriteFunc
        Function used to save predictions.
    write_extension : str
        Extension added to prediction file paths.
    write_func_kwargs : dict of {str: Any}
        Extra kwargs to pass to `write_func`.
    """
    super().__init__()

    # everything needed later on by `write_batch` to save a prediction
    self.write_func = write_func
    self.write_extension = write_extension
    self.write_func_kwargs = write_func_kwargs
|
|
343
|
+
|
|
344
|
+
def write_batch(
    self,
    trainer: Trainer,
    pl_module: LightningModule,
    prediction: NDArray,
    batch_indices: Sequence[int] | None,
    batch: NDArray,
    batch_idx: int,
    dataloader_idx: int,
    dirpath: Path,
) -> None:
    """
    Save full images.

    Each sample of `prediction` (first axis) is written to its own file. The
    output path is derived from the input file path of the sample, looked up
    with the sample id `batch_idx * batch_size + i`.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning Trainer.
    pl_module : LightningModule
        PyTorch Lightning LightningModule.
    prediction : numpy.ndarray
        Predictions on `batch`, the first axis indexes the samples.
    batch_indices : sequence of int
        Indices identifying the samples in the batch.
    batch : numpy.ndarray
        Input batch.
    batch_idx : int
        Batch index.
    dataloader_idx : int
        Dataloader index.
    dirpath : Path
        Path to directory to save predictions to.

    Raises
    ------
    TypeError
        If trainer prediction dataset is not `IterablePredDataset`.
    """
    dls: Union[DataLoader, list[DataLoader]] = trainer.predict_dataloaders
    dl: DataLoader = dls[dataloader_idx] if isinstance(dls, list) else dls
    ds: IterablePredDataset = dl.dataset
    if not isinstance(ds, IterablePredDataset):
        raise TypeError("Prediction dataset is not `IterablePredDataset`.")

    # Iterate over the sample axis so that every file receives the prediction
    # of its own sample; always taking `prediction[0]` would save the first
    # sample of the batch under every file name.
    for i, prediction_image in enumerate(prediction):
        # NOTE(review): assumes batches arrive in dataset order and that every
        # batch but the last one holds `dl.batch_size` samples - confirm.
        sample_id = batch_idx * dl.batch_size + i
        input_file_path = get_sample_file_path(dataset=ds, sample_id=sample_id)
        file_path = create_write_file_path(
            dirpath=dirpath,
            file_path=input_file_path,
            write_extension=self.write_extension,
        )
        self.write_func(
            file_path=file_path, img=prediction_image, **self.write_func_kwargs
        )
|