careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/careamist.py
ADDED
|
@@ -0,0 +1,961 @@
|
|
|
1
|
+
"""A class to train, predict and export models in CAREamics."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Literal, Union, overload
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from pytorch_lightning import Trainer
|
|
10
|
+
from pytorch_lightning.callbacks import (
|
|
11
|
+
Callback,
|
|
12
|
+
EarlyStopping,
|
|
13
|
+
ModelCheckpoint,
|
|
14
|
+
)
|
|
15
|
+
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
|
|
16
|
+
|
|
17
|
+
from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
|
|
18
|
+
from careamics.config.support import (
|
|
19
|
+
SupportedAlgorithm,
|
|
20
|
+
SupportedArchitecture,
|
|
21
|
+
SupportedData,
|
|
22
|
+
SupportedLogger,
|
|
23
|
+
)
|
|
24
|
+
from careamics.dataset.dataset_utils import list_files, reshape_array
|
|
25
|
+
from careamics.file_io import WriteFunc, get_write_func
|
|
26
|
+
from careamics.lightning import (
|
|
27
|
+
FCNModule,
|
|
28
|
+
HyperParametersCallback,
|
|
29
|
+
PredictDataModule,
|
|
30
|
+
ProgressBarCallback,
|
|
31
|
+
TrainDataModule,
|
|
32
|
+
create_predict_datamodule,
|
|
33
|
+
)
|
|
34
|
+
from careamics.model_io import export_to_bmz, load_pretrained
|
|
35
|
+
from careamics.prediction_utils import convert_outputs
|
|
36
|
+
from careamics.utils import check_path_exists, get_logger
|
|
37
|
+
from careamics.utils.lightning_utils import read_csv_logger
|
|
38
|
+
|
|
39
|
+
# module-level logger
logger = get_logger(__name__)

# type alias for the list of experiment loggers handed to the Lightning `Trainer`
# (a CSV logger is always part of it, optionally with TensorBoard or WandB)
LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# TODO: "type: ignore" comments have been added because of the czi data type in the data configuration
|
|
45
|
+
class CAREamist:
|
|
46
|
+
"""Main CAREamics class, allowing training and prediction using various algorithms.
|
|
47
|
+
|
|
48
|
+
Parameters
|
|
49
|
+
----------
|
|
50
|
+
source : pathlib.Path or str or CAREamics Configuration
|
|
51
|
+
Path to a configuration file or a trained model.
|
|
52
|
+
work_dir : str, optional
|
|
53
|
+
Path to working directory in which to save checkpoints and logs,
|
|
54
|
+
by default None.
|
|
55
|
+
callbacks : list of Callback, optional
|
|
56
|
+
List of callbacks to use during training and prediction, by default None.
|
|
57
|
+
enable_progress_bar : bool
|
|
58
|
+
Whether a progress bar will be displayed during training, validation and
|
|
59
|
+
prediction.
|
|
60
|
+
|
|
61
|
+
Attributes
|
|
62
|
+
----------
|
|
63
|
+
model : CAREamicsModule
|
|
64
|
+
CAREamics model.
|
|
65
|
+
cfg : Configuration
|
|
66
|
+
CAREamics configuration.
|
|
67
|
+
trainer : Trainer
|
|
68
|
+
PyTorch Lightning trainer.
|
|
69
|
+
experiment_logger : TensorBoardLogger or WandbLogger
|
|
70
|
+
Experiment logger, "wandb" or "tensorboard".
|
|
71
|
+
work_dir : pathlib.Path
|
|
72
|
+
Working directory.
|
|
73
|
+
train_datamodule : TrainDataModule
|
|
74
|
+
Training datamodule.
|
|
75
|
+
pred_datamodule : PredictDataModule
|
|
76
|
+
Prediction datamodule.
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
# typing-only overload: construct from a path to a configuration file or model
@overload
def __init__(  # numpydoc ignore=GL08
    self,
    source: Path | str,
    work_dir: Path | str | None = None,
    callbacks: list[Callback] | None = None,
    enable_progress_bar: bool = True,
) -> None: ...


# typing-only overload: construct from an in-memory `Configuration`
@overload
def __init__(  # numpydoc ignore=GL08
    self,
    source: Configuration,
    work_dir: Path | str | None = None,
    callbacks: list[Callback] | None = None,
    enable_progress_bar: bool = True,
) -> None: ...
|
|
96
|
+
|
|
97
|
+
def __init__(
    self,
    source: Union[Path, str, Configuration],
    work_dir: Union[Path, str] | None = None,
    callbacks: list[Callback] | None = None,
    enable_progress_bar: bool = True,
) -> None:
    """
    Initialize CAREamist with a configuration object or a path.

    A configuration object can be created using directly by calling `Configuration`,
    using the configuration factory or loading a configuration from a yaml file.

    Path can contain either a yaml file with parameters, or a saved checkpoint.

    If no working directory is provided, the current working directory is used.

    A CSV logger (writing to `work_dir / "csv_logs"`) is always used. If the
    training configuration selects WandB or TensorBoard, that logger is added
    as well.

    Parameters
    ----------
    source : pathlib.Path or str or CAREamics Configuration
        Path to a configuration file or a trained model.
    work_dir : str or pathlib.Path, optional
        Path to working directory in which to save checkpoints and logs,
        by default None.
    callbacks : list of Callback, optional
        List of callbacks to use during training and prediction, by default None.
    enable_progress_bar : bool
        Whether a progress bar will be displayed during training, validation and
        prediction.

    Raises
    ------
    NotImplementedError
        If the algorithm configuration is not a `UNetBasedAlgorithm`.
    NotImplementedError
        If the model is loaded from BioImage Model Zoo.
    ValueError
        If no hyper parameters are found in the checkpoint.
    ValueError
        If no data module hyper parameters are found in the checkpoint.
    """
    # select current working directory if work_dir is None
    if work_dir is None:
        self.work_dir = Path.cwd()
        logger.warning(
            f"No working directory provided. Using current working directory: "
            f"{self.work_dir}."
        )
    else:
        self.work_dir = Path(work_dir)

    if isinstance(source, Configuration):
        # configuration object
        self.cfg = source
        self.model = self._instantiate_model()
    else:
        # path to configuration file or model
        # TODO: update this check so models can be downloaded directly from BMZ
        source = check_path_exists(source)

        if source.is_file() and source.suffix in (".yaml", ".yml"):
            # configuration file: build a fresh model from it
            self.cfg = load_configuration(source)
            self.model = self._instantiate_model()
        else:
            # attempt loading a pre-trained model
            self.model, self.cfg = load_pretrained(source)

    # define the checkpoint saving callback
    self._define_callbacks(callbacks, enable_progress_bar)

    # instantiate loggers, the CSV logger is always present
    csv_logger = CSVLogger(
        name=self.cfg.experiment_name,
        save_dir=self.work_dir / "csv_logs",
    )

    # resolve the requested experiment logger once; a flat if/elif/else
    # guarantees `experiment_logger` is always bound, even if the configured
    # logger is neither WandB nor TensorBoard
    logger_type = (
        self.cfg.training_config.logger
        if self.cfg.training_config.has_logger()
        else None
    )
    experiment_logger: LOGGER_TYPES
    if logger_type == SupportedLogger.WANDB:
        experiment_logger = [
            WandbLogger(
                name=self.cfg.experiment_name,
                save_dir=self.work_dir / Path("wandb_logs"),
            ),
            csv_logger,
        ]
    elif logger_type == SupportedLogger.TENSORBOARD:
        experiment_logger = [
            TensorBoardLogger(
                save_dir=self.work_dir / Path("tb_logs"),
            ),
            csv_logger,
        ]
    else:
        experiment_logger = [csv_logger]

    # instantiate trainer; user-provided Lightning trainer options are
    # forwarded as keyword arguments
    self.trainer = Trainer(
        enable_progress_bar=enable_progress_bar,
        callbacks=self.callbacks,
        default_root_dir=self.work_dir,
        logger=experiment_logger,
        **(self.cfg.training_config.lightning_trainer_config or {}),
    )

    # place holder for the datamodules
    self.train_datamodule: TrainDataModule | None = None
    self.pred_datamodule: PredictDataModule | None = None


def _instantiate_model(self) -> FCNModule:
    """Create a new model from the algorithm configuration in `self.cfg`.

    Returns
    -------
    FCNModule
        Freshly initialized model.

    Raises
    ------
    NotImplementedError
        If the algorithm configuration is not a `UNetBasedAlgorithm`.
    """
    if not isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
        raise NotImplementedError("Architecture not supported.")

    return FCNModule(algorithm_config=self.cfg.algorithm_config)
|
|
222
|
+
|
|
223
|
+
def _define_callbacks(
    self, callbacks: list[Callback] | None, enable_progress_bar: bool
) -> None:
    """Define the callbacks for the training loop.

    User-provided callbacks are validated and then combined with the callbacks
    managed by CAREamics: a `HyperParametersCallback`, a `ModelCheckpoint`
    configured from the training configuration, an optional
    `ProgressBarCallback` and an optional `EarlyStopping` callback. The result
    is stored in `self.callbacks`.

    Parameters
    ----------
    callbacks : list of Callback, optional
        List of callbacks to use during training and prediction, by default None.
        The list is copied, so the caller's list is not modified.
    enable_progress_bar : bool
        Whether a progress bar will be displayed during training, validation and
        prediction. It controls whether a `ProgressBarCallback` is added to the
        callback list.

    Raises
    ------
    ValueError
        If a `ModelCheckpoint` or `EarlyStopping` callback is passed, as these
        are configured through the training configuration.
    ValueError
        If a `HyperParametersCallback` or `ProgressBarCallback` is passed, as
        these are created internally.
    """
    # copy the user list: the internal callbacks are appended below and must
    # not leak into the caller's list (reusing it would then raise)
    self.callbacks = [] if callbacks is None else list(callbacks)

    # check that user callbacks are not any of the CAREamics callbacks
    for c in self.callbacks:
        if isinstance(c, (ModelCheckpoint, EarlyStopping)):
            raise ValueError(
                "ModelCheckpoint and EarlyStopping callbacks are already defined "
                "in CAREamics and should only be modified through the "
                "training configuration (see TrainingConfig)."
            )

        if isinstance(c, (HyperParametersCallback, ProgressBarCallback)):
            raise ValueError(
                "HyperParameter and ProgressBar callbacks are defined internally "
                "and should not be passed as callbacks."
            )

    # checkpoint callback saves checkpoints during training
    self.callbacks.extend(
        [
            HyperParametersCallback(self.cfg),
            ModelCheckpoint(
                dirpath=self.work_dir / Path("checkpoints"),
                filename=f"{self.cfg.experiment_name}_{{epoch:02d}}_step_{{step}}",
                **self.cfg.training_config.checkpoint_callback.model_dump(),
            ),
        ]
    )
    if enable_progress_bar:
        self.callbacks.append(ProgressBarCallback())

    # early stopping callback: unpack its configuration as keyword arguments,
    # as done for the checkpoint callback. Passing the config object
    # positionally would bind it to `EarlyStopping(monitor=...)`.
    early_stopping_config = self.cfg.training_config.early_stopping_callback
    if early_stopping_config is not None:
        self.callbacks.append(EarlyStopping(**early_stopping_config.model_dump()))
+
def stop_training(self) -> None:
    """Ask the trainer to end the current training loop.

    The trainer's `should_stop` flag is raised and `limit_val_batches` is set
    to 0 so that no validation is run.
    """
    trainer = self.trainer
    # raise stop training flag
    trainer.should_stop = True
    # skip validation
    trainer.limit_val_batches = 0
|
|
281
|
+
|
|
282
|
+
# TODO: is there a more elegant way than calling train again after _train_on_path?
|
|
283
|
+
def train(
    self,
    *,
    datamodule: TrainDataModule | None = None,
    train_source: Union[Path, str, NDArray] | None = None,
    val_source: Union[Path, str, NDArray] | None = None,
    train_target: Union[Path, str, NDArray] | None = None,
    val_target: Union[Path, str, NDArray] | None = None,
    use_in_memory: bool = True,
    val_percentage: float = 0.1,
    val_minimum_split: int = 1,
) -> None:
    """
    Train the model on the provided data.

    If a datamodule is provided, then training will be performed using it.
    Alternatively, the training data can be provided as arrays or paths.

    If `use_in_memory` is set to True, the source provided as Path or str will be
    loaded in memory if it fits. Otherwise, training will be performed by loading
    patches from the files one by one. Training on arrays is always performed
    in memory.

    If no validation source is provided, then the validation is extracted from
    the training data using `val_percentage` and `val_minimum_split`. In the case
    of data provided as Path or str, the percentage and minimum number are applied
    to the number of files. For arrays, it is the number of patches.

    Parameters
    ----------
    datamodule : TrainDataModule, optional
        Datamodule to train on, by default None.
    train_source : pathlib.Path or str or NDArray, optional
        Train source, if no datamodule is provided, by default None.
    val_source : pathlib.Path or str or NDArray, optional
        Validation source, if no datamodule is provided, by default None.
    train_target : pathlib.Path or str or NDArray, optional
        Train target source, if no datamodule is provided, by default None.
    val_target : pathlib.Path or str or NDArray, optional
        Validation target source, if no datamodule is provided, by default None.
    use_in_memory : bool, optional
        Use in memory dataset if possible, by default True.
    val_percentage : float, optional
        Percentage of validation extracted from training data, by default 0.1.
    val_minimum_split : int, optional
        Minimum number of validation (patch or file) extracted from training data,
        by default 1.

    Raises
    ------
    ValueError
        If both `datamodule` and `train_source` are provided.
    ValueError
        If sources are not of the same kind (e.g. train is an array and val is
        a Path). `str` and `pathlib.Path` sources count as the same kind and
        may be mixed.
    ValueError
        If the training target is provided to N2V.
    ValueError
        If neither a datamodule nor a source is provided.
    """
    if datamodule is not None and train_source is not None:
        raise ValueError(
            "Only one of `datamodule` and `train_source` can be provided."
        )

    def _source_kind(source: Any) -> type:
        # group sources the way they are dispatched below: any ndarray
        # (including subclasses) is an array, and any `str` or `Path`
        # (concretely `PosixPath`/`WindowsPath`) is a path
        if isinstance(source, np.ndarray):
            return np.ndarray
        if isinstance(source, (Path, str)):
            return Path
        return type(source)

    # check that inputs are the same kind
    source_kinds = {
        _source_kind(s)
        for s in (train_source, val_source, train_target, val_target)
        if s is not None
    }
    if len(source_kinds) > 1:
        raise ValueError("All sources should be of the same type.")

    # train on a user-provided datamodule
    if datamodule is not None:
        self._train_on_datamodule(datamodule=datamodule)
        return

    # raise error if target is provided to N2V
    if (
        self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value
        and train_target is not None
    ):
        raise ValueError("Training target not compatible with N2V training.")

    # dispatch the training
    if isinstance(train_source, np.ndarray):
        # mypy checks (guaranteed by the kind check above)
        assert isinstance(val_source, np.ndarray) or val_source is None
        assert isinstance(train_target, np.ndarray) or train_target is None
        assert isinstance(val_target, np.ndarray) or val_target is None

        self._train_on_array(
            train_source,
            val_source,
            train_target,
            val_target,
            val_percentage,
            val_minimum_split,
        )

    elif isinstance(train_source, (Path, str)):
        # mypy checks (guaranteed by the kind check above)
        assert isinstance(val_source, (Path, str)) or val_source is None
        assert isinstance(train_target, (Path, str)) or train_target is None
        assert isinstance(val_target, (Path, str)) or val_target is None

        self._train_on_path(
            train_source,
            val_source,
            train_target,
            val_target,
            use_in_memory,
            val_percentage,
            val_minimum_split,
        )

    else:
        raise ValueError(
            f"Invalid input, expected a str, Path, array or TrainDataModule "
            f"instance (got {type(train_source)})."
        )
|
|
418
|
+
|
|
419
|
+
def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
    """
    Fit the model using an already constructed datamodule.

    The datamodule is stored on the instance, the trainer flags that a previous
    `stop_training` call may have changed are reset, and the trainer's `fit` is
    then called with the current model.

    Parameters
    ----------
    datamodule : TrainDataModule
        Datamodule providing the training (and validation) data.
    """
    # keep a reference to the datamodule used for this training run
    self.train_datamodule = datamodule

    # restore defaults in case `stop_training` was called before
    trainer = self.trainer
    trainer.should_stop = False
    trainer.limit_val_batches = 1.0  # validate on 100% of the batches

    trainer.fit(self.model, datamodule=datamodule)
|
|
437
|
+
|
|
438
|
+
def _train_on_array(
    self,
    train_data: NDArray,
    val_data: NDArray | None = None,
    train_target: NDArray | None = None,
    val_target: NDArray | None = None,
    val_percentage: float = 0.1,
    val_minimum_split: int = 5,
) -> None:
    """
    Wrap in-memory arrays in a `TrainDataModule` and train on it.

    Parameters
    ----------
    train_data : NDArray
        Training data.
    val_data : NDArray, optional
        Validation data, by default None.
    train_target : NDArray, optional
        Train target data, by default None.
    val_target : NDArray, optional
        Validation target data, by default None.
    val_percentage : float, optional
        Percentage of patches to use for validation, by default 0.1.
    val_minimum_split : int, optional
        Minimum number of patches to use for validation, by default 5.
    """
    datamodule_kwargs = {
        "data_config": self.cfg.data_config,
        "train_data": train_data,
        "val_data": val_data,
        "train_data_target": train_target,
        "val_data_target": val_target,
        "val_percentage": val_percentage,
        "val_minimum_split": val_minimum_split,
    }

    # delegate to the generic datamodule-based entry point
    self.train(datamodule=TrainDataModule(**datamodule_kwargs))
|
|
478
|
+
|
|
479
|
+
def _train_on_path(
    self,
    path_to_train_data: Union[Path, str],
    path_to_val_data: Union[Path, str] | None = None,
    path_to_train_target: Union[Path, str] | None = None,
    path_to_val_target: Union[Path, str] | None = None,
    use_in_memory: bool = True,
    val_percentage: float = 0.1,
    val_minimum_split: int = 1,
) -> None:
    """
    Build a `TrainDataModule` from data paths and train on it.

    Every provided path is first validated with `check_path_exists`; optional
    paths that are None are passed through unchanged.

    Parameters
    ----------
    path_to_train_data : pathlib.Path or str
        Path to the training data.
    path_to_val_data : pathlib.Path or str, optional
        Path to validation data, by default None.
    path_to_train_target : pathlib.Path or str, optional
        Path to train target data, by default None.
    path_to_val_target : pathlib.Path or str, optional
        Path to validation target data, by default None.
    use_in_memory : bool, optional
        Use in memory dataset if possible, by default True.
    val_percentage : float, optional
        Percentage of files to use for validation, by default 0.1.
    val_minimum_split : int, optional
        Minimum number of files to use for validation, by default 1.
    """
    # the training path is mandatory, the other ones are only checked if given
    path_to_train_data = check_path_exists(path_to_train_data)
    optional_paths = (path_to_val_data, path_to_train_target, path_to_val_target)
    path_to_val_data, path_to_train_target, path_to_val_target = (
        None if candidate is None else check_path_exists(candidate)
        for candidate in optional_paths
    )

    datamodule = TrainDataModule(
        data_config=self.cfg.data_config,
        train_data=path_to_train_data,
        val_data=path_to_val_data,
        train_data_target=path_to_train_target,
        val_data_target=path_to_val_target,
        use_in_memory=use_in_memory,
        val_percentage=val_percentage,
        val_minimum_split=val_minimum_split,
    )
    self.train(datamodule=datamodule)
|
|
535
|
+
|
|
536
|
+
# Typing overload: prediction from an existing `PredictDataModule`.
@overload
def predict(  # numpydoc ignore=GL08
    self,
    source: PredictDataModule,
) -> Union[list[NDArray], NDArray]: ...


# Typing overload: prediction from a path (str or pathlib.Path); file reading
# can be customised with `read_source_func` and `extension_filter`.
@overload
def predict(  # numpydoc ignore=GL08
    self,
    source: Union[Path, str],
    *,
    batch_size: int = 1,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["tiff", "custom"] | None = None,
    tta_transforms: bool = False,
    dataloader_params: dict | None = None,
    read_source_func: Callable | None = None,
    extension_filter: str = "",
) -> Union[list[NDArray], NDArray]: ...


# Typing overload: prediction from an in-memory numpy array.
@overload
def predict(  # numpydoc ignore=GL08
    self,
    source: NDArray,
    *,
    batch_size: int = 1,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["array"] | None = None,
    tta_transforms: bool = False,
    dataloader_params: dict | None = None,
) -> Union[list[NDArray], NDArray]: ...
|
|
570
|
+
|
|
571
|
+
def predict(
    self,
    source: Union[PredictDataModule, Path, str, NDArray],
    *,
    batch_size: int = 1,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["array", "tiff", "custom"] | None = None,
    tta_transforms: bool = False,
    dataloader_params: dict | None = None,
    read_source_func: Callable | None = None,
    extension_filter: str = "",
    **kwargs: Any,
) -> Union[list[NDArray], NDArray]:
    """
    Make predictions on the provided data.

    Input can be a PredictDataModule instance, a path to a data file, or a numpy
    array.

    If `data_type` or `axes` are not provided, the values from the training data
    configuration are used. If `batch_size` is falsy (e.g. 0), the configured
    batch size is used. `tile_size` has no fallback to the configuration: it is
    forwarded as-is to the prediction datamodule, so leaving it as None means no
    tile size is requested.

    Test-time augmentation (TTA) can be switched on using the `tta_transforms`
    parameter. The TTA augmentation applies all possible flip and 90 degrees
    rotations to the prediction input and averages the predictions. TTA augmentation
    should not be used if you did not train with these augmentations.

    Note that if you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. If your image has less dimensions, as it may
    happen in the Z dimension, consider padding your image.

    Parameters
    ----------
    source : PredictDataModule, pathlib.Path, str or numpy.ndarray
        Data to predict on.
    batch_size : int, default=1
        Batch size for prediction.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction, by default None.
    tile_overlap : tuple of int or None, default=(48, 48)
        Overlap between tiles. May only be None when `tile_size` is None.
    axes : str, optional
        Axes of the input data, by default None (configuration axes are used).
    data_type : {"array", "tiff", "custom"}, optional
        Type of the input data, by default None (configuration data type is
        used).
    tta_transforms : bool, default=False
        Whether to apply test-time augmentation.
    dataloader_params : dict, optional
        Parameters to pass to the dataloader.
    read_source_func : Callable, optional
        Function to read the source data.
    extension_filter : str, default=""
        Filter for the file extension.
    **kwargs : Any
        Unused.

    Returns
    -------
    list of NDArray or NDArray
        Predictions made by the model.

    Raises
    ------
    ValueError
        If mean and std are not provided in the configuration.
    ValueError
        If `tile_size` is given for a UNet model and is not divisible by
        2**depth along every axis.
    ValueError
        If `tile_size` is given but `tile_overlap` is None.
    """
    data_config = self.cfg.data_config

    # normalization statistics are required to build the prediction data
    if data_config.image_means is None or data_config.image_stds is None:
        raise ValueError("Mean and std must be provided in the configuration.")

    # tile size checks only apply when tiling is requested
    if tile_size is not None:
        model = self.cfg.algorithm_config.model

        if model.architecture == SupportedArchitecture.UNET.value:
            # tile size must be equal to k*2^n, where n is the number of pooling
            # layers (equal to the depth) and k is an integer
            tile_increment = 2**model.depth

            for i, t in enumerate(tile_size):
                if t % tile_increment != 0:
                    raise ValueError(
                        f"Tile size must be divisible by {tile_increment} along "
                        f"all axes (got {t} for axis {i}). If your image size is "
                        f"smaller along one axis (e.g. Z), consider padding the "
                        f"image."
                    )

        # tile overlaps must be specified
        if tile_overlap is None:
            raise ValueError("Tile overlap must be specified.")

    # create the prediction datamodule, falling back to the configuration for
    # the data type, axes and batch size
    self.pred_datamodule = create_predict_datamodule(
        pred_data=source,
        data_type=data_type or data_config.data_type,  # type: ignore
        axes=axes or data_config.axes,
        image_means=data_config.image_means,
        image_stds=data_config.image_stds,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        batch_size=batch_size or data_config.batch_size,
        tta_transforms=tta_transforms,
        read_source_func=read_source_func,
        extension_filter=extension_filter,
        dataloader_params=dataloader_params,
    )

    # predict
    predictions = self.trainer.predict(
        model=self.model, datamodule=self.pred_datamodule
    )
    return convert_outputs(predictions, self.pred_datamodule.tiled)
|
|
696
|
+
|
|
697
|
+
def predict_to_disk(
    self,
    source: Union[PredictDataModule, Path, str],
    *,
    batch_size: int = 1,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["tiff", "custom"] | None = None,
    tta_transforms: bool = False,
    dataloader_params: dict | None = None,
    read_source_func: Callable | None = None,
    extension_filter: str = "",
    write_type: Literal["tiff", "custom"] = "tiff",
    write_extension: str | None = None,
    write_func: WriteFunc | None = None,
    write_func_kwargs: dict[str, Any] | None = None,
    prediction_dir: Union[Path, str] = "predictions",
    **kwargs,
) -> None:
    """
    Make predictions on the provided data and save outputs to files.

    The predictions are saved in `prediction_dir`, by default a 'predictions'
    directory within the set working directory. When `source` is a directory,
    the directory structure within `prediction_dir` mirrors that of the source
    directory.

    The `source` must be from files and not arrays; numpy arrays are not
    supported. Each source file produces exactly one output file with the same
    name but with the extension of `write_type` (or `write_extension` for custom
    types). The predictions for all samples of a file are concatenated along the
    first axis and written together. E.g. if the source file name is
    'images.tif', the prediction is saved as 'images.tiff' when `write_type` is
    "tiff".

    Input can be a PredictDataModule instance or a path to a data file or
    directory.

    If `data_type` or `axes` are not provided, the values from the training data
    configuration are used. `tile_size` has no fallback to the configuration.

    Test-time augmentation (TTA) can be switched on using the `tta_transforms`
    parameter. The TTA augmentation applies all possible flip and 90 degrees
    rotations to the prediction input and averages the predictions. TTA augmentation
    should not be used if you did not train with these augmentations.

    Note that if you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. If your image has less dimensions, as it may
    happen in the Z dimension, consider padding your image.

    Parameters
    ----------
    source : PredictDataModule or pathlib.Path, str
        Data to predict on.
    batch_size : int, default=1
        Batch size for prediction.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction.
    tile_overlap : tuple of int, default=(48, 48)
        Overlap between tiles.
    axes : str, optional
        Axes of the input data, by default None.
    data_type : {"tiff", "custom"}, optional
        Type of the input data, by default None (configuration data type is
        used).
    tta_transforms : bool, default=False
        Whether to apply test-time augmentation.
    dataloader_params : dict, optional
        Parameters to pass to the dataloader.
    read_source_func : Callable, optional
        Function to read the source data.
    extension_filter : str, default=""
        Filter for the file extension. Replaced by the data module's filter for
        `PredictDataModule` sources and by the data-type pattern for path sources.
    write_type : {"tiff", "custom"}, default="tiff"
        The data type to save as, includes custom.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed. A leading
        dot is added if missing.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.
    prediction_dir : Path | str, default="predictions"
        The path to save the prediction results to. If `prediction_dir` is not
        absolute, the directory will be assumed to be relative to the pre-set
        `work_dir`. If the directory does not exist it will be created.
    **kwargs : Any
        Passed on to `predict`.

    Raises
    ------
    ValueError
        If `write_type` is custom and `write_extension` is None.
    ValueError
        If `write_type` is custom and `write_func` is None.
    ValueError
        If `source` is not `str`, `Path` or `PredictDataModule`.
    ValueError
        If the source data type is 'array'.
    """
    if write_func_kwargs is None:
        write_func_kwargs = {}

    # guards for custom types, checked before touching the file system so that
    # a misconfigured call does not leave an empty output directory behind
    if write_type == SupportedData.CUSTOM:
        if write_extension is None:
            raise ValueError(
                "A `write_extension` must be provided for custom write types."
            )
        if write_func is None:
            raise ValueError(
                "A `write_func` must be provided for custom write types."
            )
        # `Path.with_suffix` rejects non-empty suffixes without a leading dot
        if write_extension and not write_extension.startswith("."):
            write_extension = f".{write_extension}"
    else:
        write_func = get_write_func(write_type)
        write_extension = SupportedData.get_extension(write_type)

    # extract source location and data type
    source_path: Union[Path, str, NDArray]
    source_data_type: Literal["array", "tiff", "custom"]
    if isinstance(source, PredictDataModule):
        source_path = source.pred_data
        source_data_type = source.data_type  # type: ignore
        extension_filter = source.extension_filter
    elif isinstance(source, (str | Path)):
        source_path = source
        source_data_type = (
            data_type or self.cfg.data_config.data_type  # type: ignore
        )
    else:
        raise ValueError(f"Unsupported source type: '{type(source)}'.")

    # reject arrays before deriving a file extension pattern from the data type,
    # so that users get this explicit error message
    if source_data_type == "array":
        raise ValueError(
            "Predicting to disk is not supported for input type 'array'."
        )
    if not isinstance(source, PredictDataModule):
        extension_filter = SupportedData.get_extension_pattern(
            SupportedData(source_data_type)
        )

    assert isinstance(source_path, (Path | str))  # because data_type != "array"
    source_path = Path(source_path)

    file_paths = list_files(source_path, source_data_type, extension_filter)

    # resolve and create the output directory once all inputs are validated
    if Path(prediction_dir).is_absolute():
        write_dir = Path(prediction_dir)
    else:
        write_dir = self.work_dir / prediction_dir
    write_dir.mkdir(exist_ok=True, parents=True)

    # predict and write each file in turn
    for file_path in file_paths:
        prediction = self.predict(
            source=file_path,
            batch_size=batch_size,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
            axes=axes,
            data_type=data_type,
            tta_transforms=tta_transforms,
            dataloader_params=dataloader_params,
            read_source_func=read_source_func,
            extension_filter=extension_filter,
            **kwargs,
        )
        # TODO: cast to float16?
        # all samples of one source file are written to a single output file
        write_data = np.concatenate(prediction)

        # mirror the source directory structure when predicting on a directory
        if not source_path.is_file():
            file_write_dir = write_dir / file_path.parent.relative_to(source_path)
        else:
            file_write_dir = write_dir
        file_write_dir.mkdir(parents=True, exist_ok=True)
        write_path = (file_write_dir / file_path.name).with_suffix(write_extension)

        # write data, forwarding the user-provided keyword arguments
        write_func(file_path=write_path, img=write_data, **write_func_kwargs)
|
|
876
|
+
|
|
877
|
+
def export_to_bmz(
    self,
    path_to_archive: Union[Path | str],
    friendly_model_name: str,
    input_array: NDArray,
    authors: list[dict],
    general_description: str,
    data_description: str,
    covers: list[Union[Path, str]] | None = None,
    channel_names: list[str] | None = None,
    model_version: str = "0.1.0",
) -> None:
    """Package the current model as a BioImage Model Zoo archive.

    The model is first run (without test-time augmentation) on `input_array`
    to obtain an example output. The input, reshaped according to the
    configured axes, and that output are then handed, together with the model,
    its configuration and the metadata below, to the BMZ export routine. The
    archive contains the weights, the model specifications and accompanying
    files (inputs, outputs, README, env.yaml etc.).

    `path_to_archive` should point to a file with a ".zip" extension, and
    `friendly_model_name` (the name shown in the BMZ specs and website) should
    consist of letters, numbers, dashes, underscores and parentheses only.

    `input_array` must have the same dimensions as the axes recorded in the
    configuration of the `CAREamist`.

    Parameters
    ----------
    path_to_archive : pathlib.Path or str
        Destination of the archive, including a file name ending in ".zip".
    friendly_model_name : str
        Name of the model as used in the BMZ specs, it should consist of letters,
        numbers, dashes, underscores and parentheses only.
    input_array : NDArray
        Example input, used to validate the model and stored in the archive.
    authors : list of dict
        List of authors of the model.
    general_description : str
        General description of the model used in the BMZ metadata.
    data_description : str
        Description of the data the model was trained on.
    covers : list of pathlib.Path or str, default=None
        Paths to the cover images.
    channel_names : list of str, default=None
        Channel names.
    model_version : str, default="0.1.0"
        Version of the model.
    """
    # run the model once on the example input to obtain the example output
    prediction_chunks = self.predict(
        input_array,
        data_type=SupportedData.ARRAY.value,
        tta_transforms=False,
    )
    example_output = np.concatenate(prediction_chunks, axis=0)

    # reshape the example input according to the configured axes
    example_input = reshape_array(input_array, self.cfg.data_config.axes)

    export_to_bmz(
        model=self.model,
        config=self.cfg,
        path_to_archive=path_to_archive,
        model_name=friendly_model_name,
        general_description=general_description,
        data_description=data_description,
        authors=authors,
        input_array=example_input,
        output_array=example_output,
        covers=covers,
        channel_names=channel_names,
        model_version=model_version,
    )
|
|
952
|
+
|
|
953
|
+
def get_losses(self) -> dict[str, list]:
    """Collect the logged losses so they can be plotted as curves.

    The values are read from the CSV logs stored in the `csv_logs` folder of
    the working directory, for the current experiment name.

    Returns
    -------
    dict of str: list
        Dictionary containing the losses for each epoch.
    """
    experiment = self.cfg.experiment_name
    log_dir = self.work_dir / "csv_logs"
    return read_csv_logger(experiment, log_dir)
|