careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""Generic UNet Lightning DataModule."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Union
|
|
4
|
+
|
|
5
|
+
import pytorch_lightning as L
|
|
6
|
+
import torch
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torchmetrics import MetricCollection
|
|
9
|
+
from torchmetrics.image import PeakSignalNoiseRatio
|
|
10
|
+
|
|
11
|
+
from careamics.config import algorithm_factory
|
|
12
|
+
from careamics.config.algorithms import (
|
|
13
|
+
CAREAlgorithm,
|
|
14
|
+
N2NAlgorithm,
|
|
15
|
+
N2VAlgorithm,
|
|
16
|
+
PN2VAlgorithm,
|
|
17
|
+
)
|
|
18
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
19
|
+
from careamics.models.unet import UNet
|
|
20
|
+
from careamics.transforms import Denormalize
|
|
21
|
+
from careamics.utils.logging import get_logger
|
|
22
|
+
from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
23
|
+
|
|
24
|
+
# Module-level logger obtained from the CAREamics logging helper for `__name__`;
# used below to report checkpoint-loading progress and warnings.
logger = get_logger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class UnetModule(L.LightningModule):
    """CAREamics PyTorch Lightning module for UNet based algorithms.

    The module wraps a `UNet` built from the algorithm configuration and provides
    logging helpers, prediction with denormalization, and optimizer/scheduler
    configuration.

    Parameters
    ----------
    algorithm_config : CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, PN2VAlgorithm or dict
        Configuration for the algorithm, either as an instance of a specific algorithm
        class or a dictionary that can be converted to an algorithm instance.
    """

    def __init__(
        self,
        algorithm_config: Union[
            CAREAlgorithm, N2VAlgorithm, N2NAlgorithm, PN2VAlgorithm, dict
        ],
    ) -> None:
        """Instantiate the UNet Lightning module.

        Parameters
        ----------
        algorithm_config : UNet-based algorithm configuration or dict
            Configuration for the algorithm, either as an instance of a specific
            algorithm class (`CAREAlgorithm`, `N2VAlgorithm`, `N2NAlgorithm` or
            `PN2VAlgorithm`) or a dictionary that can be converted to an algorithm
            instance.
        """
        super().__init__()

        # dictionaries are converted to a validated algorithm configuration
        if isinstance(algorithm_config, dict):
            algorithm_config = algorithm_factory(algorithm_config)

        # algorithm configuration, also read by `configure_optimizers`
        self.config = algorithm_config
        # UNet instantiated from the dumped model configuration
        self.model: nn.Module = UNet(**algorithm_config.model.model_dump())

        # whether `predict_step` has already attempted to load the best checkpoint
        self._best_checkpoint_loaded = False

        # TODO: how to support metric evaluation better
        self.metrics = MetricCollection(PeakSignalNoiseRatio())

    def forward(self, x: Any) -> Any:
        """Default forward method.

        Parameters
        ----------
        x : Any
            Input data.

        Returns
        -------
        Any
            Output from the model.
        """
        return self.model(x)

    def _log_training_stats(self, loss: Any, batch_size: Any) -> None:
        """Log training statistics.

        Logs the loss under `train_loss` (per step and per epoch) and the learning
        rate of the first parameter group under `learning_rate` (per epoch).

        Parameters
        ----------
        loss : Any
            The loss value for the current training step.
        batch_size : Any
            The size of the batch used in the current training step.
        """
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=batch_size,
        )

        # `optimizers()` may return a single optimizer or a list of them; only the
        # first optimizer's first parameter group is reported
        optimizer = self.optimizers()
        if isinstance(optimizer, list):
            current_lr = optimizer[0].param_groups[0]["lr"]
        else:
            current_lr = optimizer.param_groups[0]["lr"]
        self.log(
            "learning_rate",
            current_lr,
            on_step=False,
            on_epoch=True,
            logger=True,
            batch_size=batch_size,
        )

    def _log_validation_stats(self, loss: Any, batch_size: Any) -> None:
        """Log validation statistics.

        Logs the loss under `val_loss` and the metric collection, both aggregated
        per epoch.

        Parameters
        ----------
        loss : Any
            The loss value for the current validation step.
        batch_size : Any
            The size of the batch used in the current validation step.
        """
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=batch_size,
        )
        self.log_dict(self.metrics, on_step=False, on_epoch=True, batch_size=batch_size)

    def _load_best_checkpoint(self) -> None:
        """Load the best checkpoint from the trainer's checkpoint callback.

        A warning is logged, and nothing is loaded, if the trainer has no checkpoint
        callback or if the callback has no best model path.
        """
        checkpoint_callback = getattr(self.trainer, "checkpoint_callback", None)
        if checkpoint_callback is None:
            logger.warning("No checkpoint callback found, cannot load best checkpoint.")
            return

        best_model_path = checkpoint_callback.best_model_path
        if not best_model_path:
            # covers both a missing path and an empty string
            logger.warning("No best checkpoint found.")
            return

        logger.info(f"Loading best checkpoint from: {best_model_path}")
        model_state = torch.load(best_model_path, weights_only=True)["state_dict"]
        self.load_state_dict(model_state)

    def predict_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
        load_best_checkpoint: bool = False,
    ) -> Any:
        """Default predict step.

        The first element of the batch is passed through the model, the result is
        denormalized with the statistics held by the trainer's datamodule, and
        returned together with the input's source, shape, dtype, axes and region
        specification.

        Parameters
        ----------
        batch : (ImageRegionData,) or (ImageRegionData, ImageRegionData)
            A tuple containing the input data and optionally the target data. Only
            the first element is used.
        batch_idx : Any
            The index of the current batch in the prediction loop.
        load_best_checkpoint : bool, default=False
            Whether to load the best checkpoint before making predictions. Loading
            is attempted at most once per module instance.

        Returns
        -------
        Any
            The output batch containing the predictions.
        """
        if load_best_checkpoint and not self._best_checkpoint_loaded:
            self._load_best_checkpoint()
            # set unconditionally so that the attempt (and any warning it logs) is
            # not repeated for every batch
            self._best_checkpoint_loaded = True

        x = batch[0]
        # TODO: add TTA
        prediction = self.model(x.data).cpu().numpy()

        # normalization statistics are read from the datamodule attached to the
        # trainer, accessed through the public `trainer` property as elsewhere in
        # this class
        stats = self.trainer.datamodule.stats
        denormalize = Denormalize(
            image_means=stats.means,
            image_stds=stats.stds,
        )
        denormalized_output = denormalize(prediction)

        # the input's additional metadata is not propagated to the output
        output_batch = ImageRegionData(
            data=denormalized_output,
            source=x.source,
            data_shape=x.data_shape,
            dtype=x.dtype,
            axes=x.axes,
            region_spec=x.region_spec,
            additional_metadata={},
        )
        return output_batch

    def configure_optimizers(self) -> Any:
        """Configure optimizers.

        The optimizer and learning rate scheduler are looked up by the names given
        in the algorithm configuration and instantiated with the configured
        parameters. The optimizer only receives the parameters of `self.model`.

        Returns
        -------
        Any
            A dictionary containing the optimizer and learning rate scheduler.
        """
        optimizer_func = get_optimizer(self.config.optimizer.name)
        optimizer = optimizer_func(
            self.model.parameters(), **self.config.optimizer.parameters
        )

        scheduler_func = get_scheduler(self.config.lr_scheduler.name)
        scheduler = scheduler_func(optimizer, **self.config.lr_scheduler.parameters)

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
        }
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Prediction utilities for the NG Dataset."""
|
|
2
|
+
|
|
3
|
+
# Public names of this package; each one is imported further down in this file
# from the `convert_prediction` or `stitch_prediction` submodule.
__all__ = [
    "combine_samples",
    "convert_prediction",
    "decollate_image_region_data",
    "stitch_prediction",
    "stitch_single_prediction",
]
|
|
10
|
+
|
|
11
|
+
from .convert_prediction import (
|
|
12
|
+
combine_samples,
|
|
13
|
+
convert_prediction,
|
|
14
|
+
decollate_image_region_data,
|
|
15
|
+
)
|
|
16
|
+
from .stitch_prediction import stitch_prediction, stitch_single_prediction
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""Module containing functions to convert prediction outputs to desired form."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
9
|
+
|
|
10
|
+
from .stitch_prediction import group_tiles_by_key, stitch_prediction
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from torch import Tensor
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _decollate_batch_dict(
|
|
17
|
+
batched_dict: "dict[str, list | Tensor]",
|
|
18
|
+
index: int,
|
|
19
|
+
) -> dict[str, int | tuple[int, ...]]:
|
|
20
|
+
"""
|
|
21
|
+
Decollate element `index` from a batched_dict.
|
|
22
|
+
|
|
23
|
+
This method is only compatible with integer elements.
|
|
24
|
+
|
|
25
|
+
Parameters
|
|
26
|
+
----------
|
|
27
|
+
batched_dict : dict of {str: list or Tensor}
|
|
28
|
+
Batch dictionary where each value is a list of elements of length B or a
|
|
29
|
+
Tensor of shape (B,).
|
|
30
|
+
index : int
|
|
31
|
+
Index of the element to extract.
|
|
32
|
+
|
|
33
|
+
Returns
|
|
34
|
+
-------
|
|
35
|
+
dict of {str: int | tuple[int, ...]}
|
|
36
|
+
Dictionary of the `index` element in the collated batch.
|
|
37
|
+
"""
|
|
38
|
+
item_dict = {
|
|
39
|
+
key: (
|
|
40
|
+
# cast to int otherwise we have Tensor scalars
|
|
41
|
+
# TODO for additional types (e.g. axes in additional_metadata), we will need
|
|
42
|
+
# to handle it differently
|
|
43
|
+
tuple(int(value[idx][index]) for idx in range(len(value)))
|
|
44
|
+
if isinstance(value, list)
|
|
45
|
+
else int(value[index])
|
|
46
|
+
) # handles tensor (1D) vs list of 1D tensors (2D)
|
|
47
|
+
for key, value in batched_dict.items()
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
return item_dict
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def decollate_image_region_data(
    batch: ImageRegionData,
) -> list[ImageRegionData]:
    """
    Split a collated `ImageRegionData` batch into one `ImageRegionData` per item.

    The collated batch is expected to be laid out as follows:
    - data: (B, C, (Z), Y, X) numpy.ndarray
    - source: sequence of str, length B
    - data_shape: sequence of tuple of int, each tuple being of length B
    - dtype: list of numpy.dtype, length B
    - axes: list of str, length B
    - region_spec: dict of {str: sequence}, each sequence being of length B
    - additional_metadata: dict of {str: Any}, each sequence being of length B

    Parameters
    ----------
    batch : ImageRegionData
        Batch of `ImageRegionData`.

    Returns
    -------
    list of ImageRegionData
        List of `ImageRegionData`, one per element along the first axis of
        `batch.data`.
    """
    regions: list[ImageRegionData] = []
    n_items = batch.data.shape[0]
    for item in range(n_items):
        # works the same way for PatchSpecs and TileSpecs region specifications
        item_spec = _decollate_batch_dict(batch.region_spec, item)

        # additional metadata currently only holds zarr chunks and shards, stored as
        # tuples.
        # TODO if additional metadata becomes used for anything else, this function
        # call may not be appropriate anymore.
        item_metadata = _decollate_batch_dict(batch.additional_metadata, item)

        # each entry of data_shape holds one dimension for the whole batch
        assert isinstance(batch.data_shape, list)
        item_shape = tuple(int(dimension[item]) for dimension in batch.data_shape)

        regions.append(
            ImageRegionData(
                data=batch.data[item],  # indexing drops the batch dimension
                source=batch.source[item],
                dtype=batch.dtype[item],
                data_shape=item_shape,
                axes=batch.axes[item],
                region_spec=item_spec,  # type: ignore
                additional_metadata=item_metadata,
            )
        )

    return regions
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def combine_samples(
    predictions: list[ImageRegionData],
) -> tuple[list[NDArray], list[str]]:
    """
    Stack predictions belonging to the same `data_idx` into single arrays.

    Predictions are grouped by the `data_idx` of their `region_spec`. Within each
    group they are ordered by ascending `sample_idx`, all singleton dimensions of
    each array are removed, and the arrays are stacked along a new first axis (the
    `S` dimension).

    Parameters
    ----------
    predictions : list of ImageRegionData
        List of `ImageRegionData`.

    Returns
    -------
    list of numpy.ndarray
        List of combined predictions, one per unique `data_idx`, in ascending
        `data_idx` order.
    list of str
        List of sources, one per unique `data_idx`.
    """
    by_data_idx: dict[int, list[ImageRegionData]] = group_tiles_by_key(
        predictions, key="data_idx"
    )

    stacked: list[NDArray] = []
    sources: list[str] = []
    for data_idx in sorted(by_data_idx):
        samples = by_data_idx[data_idx]

        # the source is read from the first region in grouping order, before sorting
        sources.append(samples[0].source)

        # order the samples by their sample index
        samples.sort(key=lambda region: region.region_spec["sample_idx"])

        # drop singleton dims of every sample, then stack along the S axis
        squeezed = [region.data.squeeze() for region in samples]
        stacked.append(np.stack(squeezed, axis=0))

    return stacked, sources
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def convert_prediction(
    predictions: list[ImageRegionData],
    tiled: bool,
) -> tuple[list[NDArray], list[str]]:
    """
    Convert the Lightning trainer outputs to the desired form.

    This method allows decollating batches and stitching back together tiled
    predictions.

    If the `source` of all predictions is "array" (see `InMemoryImageStack`), then the
    returned sources list will be empty.

    Parameters
    ----------
    predictions : list[ImageRegionData]
        Output from `Trainer.predict`, list of batches. May be empty.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    list of numpy.ndarray
        List of arrays with the axes SC(Z)YX.
    list of str
        List of sources, one per output or empty if all equal to `array`.

    Raises
    ------
    ValueError
        If `tiled` is False but the first prediction carries `total_tiles` in its
        `region_spec`.
    """
    # decollate batches
    decollated_predictions: list[ImageRegionData] = []
    for batch in predictions:
        decollated_predictions.extend(decollate_image_region_data(batch))

    # Only the first prediction is inspected. The emptiness guard prevents an
    # IndexError when there are no predictions at all; in that case the empty list
    # is simply handed to the combining function below.
    if (
        not tiled
        and decollated_predictions
        and "total_tiles" in decollated_predictions[0].region_spec
    ):
        raise ValueError(
            "Predictions contain `total_tiles` in region_spec but `tiled` is set to "
            "False."
        )

    if tiled:
        predictions_output, sources = stitch_prediction(decollated_predictions)
    else:
        predictions_output, sources = combine_samples(decollated_predictions)

    # in-memory arrays have no meaningful source to report
    if set(sources) == {"array"}:
        sources = []

    return predictions_output, sources
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Tiled prediction stitching utilities."""
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
|
|
10
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
11
|
+
from careamics.dataset_ng.patching_strategies import TileSpecs
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def group_tiles_by_key(
    tiles: list[ImageRegionData], key: Literal["data_idx", "sample_idx"]
) -> dict[int, list[ImageRegionData]]:
    """
    Group tiles by the value of a `region_spec` entry.

    Tiles sharing the same value for `key` end up in the same list, in the order
    in which they appear in `tiles`.

    Parameters
    ----------
    tiles : list of ImageRegionData
        List of tiles to group.
    key : {'data_idx', 'sample_idx'}
        Entry of each tile's `region_spec` to group tiles by.

    Returns
    -------
    {int: list of ImageRegionData}
        Dictionary mapping each encountered value of `key` to its tiles.
    """
    groups: dict[int, list[ImageRegionData]] = defaultdict(list)
    for region in tiles:
        groups[region.region_spec[key]].append(region)
    return groups
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def stitch_prediction(
    tiles: list[ImageRegionData],
) -> tuple[list[NDArray], list[str]]:
    """
    Reassemble full images from a flat list of predicted tiles.

    Tiles are first grouped by their `data_idx`, then each group is stitched into
    one image. Outputs are ordered by increasing `data_idx`.

    Parameters
    ----------
    tiles : list of ImageRegionData
        Cropped tiles and their respective stitching coordinates. Can contain tiles
        from multiple images.

    Returns
    -------
    list of numpy.ndarray
        Full images, may be a single image.
    list of str
        List of sources, one per output, taken from the first tile of each image.
    """
    tiles_per_image: dict[int, list[ImageRegionData]] = group_tiles_by_key(
        tiles, key="data_idx"
    )

    stitched_images: list[NDArray] = []
    sources: list[str] = []
    # process images in increasing data index so the output order is stable
    for data_idx in sorted(tiles_per_image):
        image_tiles = tiles_per_image[data_idx]
        stitched_images.append(stitch_single_prediction(image_tiles))
        sources.append(image_tiles[0].source)

    return stitched_images, sources
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def stitch_single_prediction(
    tiles: list[ImageRegionData],
) -> NDArray:
    """
    Reassemble one full image from its tiles.

    The output array is allocated from the `data_shape` of the first tile and is
    filled one sample at a time along its first axis. If the axes of the first tile
    contain "S", tiles are grouped by `sample_idx` and each group is stitched into
    its own sample; otherwise all tiles are stitched into sample 0.

    Parameters
    ----------
    tiles : list of ImageRegionData
        Cropped tiles and their respective stitching coordinates.

    Returns
    -------
    numpy.ndarray
        Full image as float32, with dimensions SC(Z)YX.
    """
    first_tile = tiles[0]
    predicted_image = np.zeros(first_tile.data_shape, dtype=np.float32)

    if "S" not in first_tile.axes:
        # no sample axis: the output is expected to have a singleton sample
        # dimension, so every tile belongs to sample 0
        predicted_image[0] = stitch_single_sample(tiles)
        return predicted_image

    samples = group_tiles_by_key(tiles, key="sample_idx")
    for sample_idx, sample_tiles in samples.items():
        stitched = stitch_single_sample(sample_tiles)

        # write the stitched sample at its position along the first axis
        predicted_image[sample_idx : sample_idx + 1] = stitched.astype(np.float32)

    return predicted_image
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def stitch_single_sample(
|
|
121
|
+
tiles: list[ImageRegionData],
|
|
122
|
+
) -> NDArray:
|
|
123
|
+
"""
|
|
124
|
+
Stitch tiles back together to form a full sample.
|
|
125
|
+
|
|
126
|
+
Tiles are of dimensions C(Z)YX, where C is the number of channels and can be a
|
|
127
|
+
singleton dimension.
|
|
128
|
+
|
|
129
|
+
Parameters
|
|
130
|
+
----------
|
|
131
|
+
tiles : list of ImageRegionData
|
|
132
|
+
Cropped tiles and their respective stitching coordinates.
|
|
133
|
+
|
|
134
|
+
Returns
|
|
135
|
+
-------
|
|
136
|
+
numpy.ndarray
|
|
137
|
+
Full sample, with dimensions C(Z)YX.
|
|
138
|
+
"""
|
|
139
|
+
data_shape = tiles[0].data_shape # SC(Z)YX
|
|
140
|
+
predicted_sample = np.zeros(data_shape[1:], dtype=np.float32)
|
|
141
|
+
|
|
142
|
+
for tile in tiles:
|
|
143
|
+
# compute crop coordinates and stitiching coordinates
|
|
144
|
+
tile_spec: TileSpecs = tile.region_spec # type: ignore
|
|
145
|
+
crop_coords = tile_spec["crop_coords"]
|
|
146
|
+
crop_size = tile_spec["crop_size"]
|
|
147
|
+
stitch_coords = tile_spec["stitch_coords"]
|
|
148
|
+
|
|
149
|
+
crop_slices: tuple[builtins.ellipsis | slice, ...] = (
|
|
150
|
+
...,
|
|
151
|
+
*[
|
|
152
|
+
slice(start, start + length)
|
|
153
|
+
for start, length in zip(crop_coords, crop_size, strict=True)
|
|
154
|
+
],
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
stitch_slices: tuple[builtins.ellipsis | slice, ...] = (
|
|
158
|
+
...,
|
|
159
|
+
*[
|
|
160
|
+
slice(start, start + length)
|
|
161
|
+
for start, length in zip(stitch_coords, crop_size, strict=True)
|
|
162
|
+
],
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
# crop predited tile according to overlap coordinates
|
|
166
|
+
cropped_tile = tile.data[crop_slices]
|
|
167
|
+
|
|
168
|
+
# insert cropped tile into predicted image
|
|
169
|
+
predicted_sample[stitch_slices] = cropped_tile.astype(np.float32)
|
|
170
|
+
|
|
171
|
+
return predicted_sample
|