careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
"""Module containing functions to convert prediction outputs to desired form."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, Union, overload
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from ..config.data.tile_information import TileInformation
|
|
9
|
+
from .stitch_prediction import stitch_prediction, stitch_prediction_vae
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
    """
    Convert the Lightning trainer outputs to the desired form.

    Batched outputs are split into one array per sample and, when `tiled` is
    True, the tiles are stitched back together into full images.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    list of numpy.ndarray
        List of arrays with the axes SC(Z)YX. A list is always returned, even if
        there is only a single output; an empty input list is returned as-is.
    """
    if len(predictions) == 0:
        return predictions

    # this layout is to stop mypy complaining
    if tiled:
        predictions_comb = combine_batches(predictions, tiled)
        predictions_output = stitch_prediction(*predictions_comb)
    else:
        predictions_output = combine_batches(predictions, tiled)

    return predictions_output
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def convert_outputs_pn2v(
    predictions: list[Any], tiled: bool
) -> tuple[list[NDArray], list[NDArray]]:
    """
    Convert PN2V Lightning trainer outputs into predictions and MMSE estimates.

    Tiled outputs are stitched back together into full images; non-tiled outputs
    are split into one array per sample.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`, one element per batch
        (the total number of tiles divided by the batch size). For tiled
        predictions each element is ((prediction, mse), tile_info_list), where
        the first dimension of each array is the batch size and the tile info
        list has one entry per batch sample. For non-tiled predictions each
        element is a (prediction, mse) pair.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    tuple[list[NDArray], list[NDArray]]
        Tuple of (predictions, mmse) where each is a list of arrays with axes
        SC(Z)YX.
    """
    if len(predictions) == 0:
        return [], []

    # TODO test with multi_channel predictions
    if not tiled:
        # Non-tiled: every batch element is a (prediction, mmse) pair.
        pred_batches, mse_batches = zip(*predictions, strict=False)
        return (
            combine_batches(list(pred_batches), tiled=False),
            combine_batches(list(mse_batches), tiled=False),
        )

    # Tiled: split each batch into (prediction, tiles) and (mmse, tiles), so that
    # both outputs keep the tile information needed for stitching.
    pred_with_tiles: list[tuple[Any, Any]] = []
    mse_with_tiles: list[tuple[Any, Any]] = []
    for (batch_pred, batch_mse), batch_tile_infos in predictions:
        pred_with_tiles.append((batch_pred, batch_tile_infos))
        mse_with_tiles.append((batch_mse, batch_tile_infos))

    stitched_preds = stitch_prediction(*combine_batches(pred_with_tiles, tiled))
    stitched_mse = stitch_prediction(*combine_batches(mse_with_tiles, tiled))

    return stitched_preds, stitched_mse
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def convert_outputs_microsplit(
    predictions: list[tuple[NDArray, NDArray]], dataset: Any
) -> tuple[NDArray, NDArray]:
    """
    Convert microsplit Lightning trainer outputs into stitched images.

    Microsplit `predict_step` returns (tile_prediction, tile_std) tuples, one per
    batch. All tiles are concatenated along the first axis (mirroring the
    concatenation done in `get_single_file_mmse`). They are then stitched back
    together with `stitch_prediction_vae`, which places the i-th tile at the grid
    location of dataset index i using the dataset's index manager.

    Parameters
    ----------
    predictions : list of tuple[NDArray, NDArray]
        Predictions from Lightning trainer for microsplit. Each element is a tuple
        of (tile_prediction, tile_std) where both are numpy arrays from
        predict_step.
    dataset : Dataset
        The dataset object used for prediction; its index manager and data shape
        drive the stitching.

    Returns
    -------
    tuple[NDArray, NDArray]
        A tuple of (stitched_predictions, stitched_stds) representing the full
        stitched predictions and standard deviations.

    Raises
    ------
    ValueError
        If `predictions` is empty.
    """
    if len(predictions) == 0:
        raise ValueError("No predictions provided")

    # Separate predictions and stds from the list of tuples
    tile_predictions = [pred for pred, _ in predictions]
    tile_stds = [std for _, std in predictions]

    # Concatenate the batches so that row i is the tile for dataset index i.
    # NOTE(review): this assumes the prediction dataloader is not shuffled.
    tiles_arr = np.concatenate(tile_predictions, axis=0)
    tile_stds_arr = np.concatenate(tile_stds, axis=0)

    # Stitch means and standard deviations with the same tiling layout
    stitched_predictions = stitch_prediction_vae(tiles_arr, dataset)
    stitched_stds = stitch_prediction_vae(tile_stds_arr, dataset)

    return stitched_predictions, stitched_stds
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: list[Any], tiled: Literal[True]
) -> tuple[list[NDArray], list[TileInformation]]: ...


# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: list[Any], tiled: Literal[False]
) -> list[NDArray]: ...


# for mypy (fallback when `tiled` is not a literal at the call site)
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: list[Any], tiled: bool
) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]: ...


def combine_batches(
    predictions: list[Any], tiled: bool
) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]:
    """
    Combine batched `Trainer.predict` outputs into one array per sample.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`. If `tiled` is False,
        a list of arrays with dimensions (B, C, (Z), Y, X), where B is the batch
        size. If `tiled` is True, a list of tuples whose first element is such an
        array and whose last element is the list of the B corresponding
        `TileInformation` objects.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
        One array with dimensions (1, C, (Z), Y, X) per sample. For tiled
        predictions, this is accompanied by the flattened list of tile
        information, in the same order.
    """
    if tiled:
        return _combine_tiled_batches(predictions)
    else:
        return _combine_array_batches(predictions)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _combine_tiled_batches(
    predictions: list[tuple[NDArray, list[TileInformation]]],
) -> tuple[list[NDArray], list[TileInformation]]:
    """
    Flatten tiled batches into per-tile arrays and a matching tile info list.

    Parameters
    ----------
    predictions : list of (numpy.ndarray, list of TileInformation)
        Predictions that are output from `Trainer.predict`. For tiled batches,
        this is a list of tuples. The first element of each tuple is the
        prediction output of tiles with dimension (B, C, (Z), Y, X), where B is
        the batch size. The second (last) element is a list of TileInformation
        objects of length B.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        One array per tile, with dimensions (1, C, (Z), Y, X), and the tile
        information in the same order.
    """
    # Flatten the per-batch lists of tile information into a single list.
    tile_infos: list[TileInformation] = []
    for *_, batch_tile_infos in predictions:
        tile_infos.extend(batch_tile_infos)

    # Split every batch array into individual tiles.
    batch_arrays = [batch_array for batch_array, *_ in predictions]
    prediction_tiles: list[NDArray] = _combine_array_batches(batch_arrays)

    return prediction_tiles, tile_infos
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _combine_array_batches(predictions: list[NDArray]) -> list[NDArray]:
|
|
222
|
+
"""
|
|
223
|
+
Combine batches of arrays.
|
|
224
|
+
|
|
225
|
+
Parameters
|
|
226
|
+
----------
|
|
227
|
+
predictions : list
|
|
228
|
+
Prediction arrays that are output from `Trainer.predict`. A list of arrays that
|
|
229
|
+
have dimensions (B, C, (Z), Y, X), where B is batch size.
|
|
230
|
+
|
|
231
|
+
Returns
|
|
232
|
+
-------
|
|
233
|
+
list of numpy.ndarray
|
|
234
|
+
A list of arrays with dimensions (1, C, (Z), Y, X).
|
|
235
|
+
"""
|
|
236
|
+
prediction_concat: NDArray = np.concatenate(predictions, axis=0)
|
|
237
|
+
prediction_split = np.split(prediction_concat, prediction_concat.shape[0], axis=0)
|
|
238
|
+
return prediction_split
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""Prediction utility functions."""
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
|
|
9
|
+
from careamics.config.data.tile_information import TileInformation
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TilingMode:
    """Integer constants identifying the tiling mode.

    This is a plain class holding ``int`` class attributes (not an
    ``enum.Enum``), so its values compare equal to plain integers.
    """

    # The numeric values are what gets compared against an index manager's
    # `tiling_mode`; keep them stable.
    TrimBoundary = 0
    PadBoundary = 1
    ShiftBoundary = 2
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def stitch_prediction_vae(predictions, dset):
    """Stitch predictions back together using dataset's index manager.

    Tile ``predictions[i]`` is placed at the grid location returned by
    ``dset.idx_manager.get_location_from_dataset_idx(i)``. Only the part of the
    tile covering its grid cell, clipped to the data bounds, is written.

    Parameters
    ----------
    predictions : numpy.ndarray
        Array of tile predictions with shape (n_tiles, channels, height, width)
        for 4D (channel-last) dataset shapes, or
        (n_tiles, channels, depth, height, width) for 5D dataset shapes.
    dset : Dataset
        Dataset object exposing ``get_data_shape()`` and an ``idx_manager``
        containing the tiling information.

    Returns
    -------
    numpy.ndarray
        Stitched predictions, channel-last, with the dataset's data shape except
        that the channel dimension is enlarged to ``predictions.shape[1]`` if the
        predictions have more channels. Its dtype is ``predictions.dtype``.

    Raises
    ------
    ValueError
        If a tile is stitched into an output that is neither 4D nor 5D, or if,
        for a 5D output, a grid cell spans more than one frame.
    """
    mng = dset.idx_manager

    # if there are more channels, use all of them.
    shape = list(dset.get_data_shape())
    shape[-1] = max(shape[-1], predictions.shape[1])

    output = np.zeros(shape, dtype=predictions.dtype)
    n_dims = output.ndim

    for dset_idx in range(predictions.shape[0]):
        # grid start, grid end
        gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
        ge = gs + mng.grid_shape

        # patch start, patch end
        ps = gs - mng.patch_offset()
        pe = ps + mng.patch_shape

        # valid grid start, valid grid end (grid cell clipped to the data bounds)
        vgs = np.array([max(0, x) for x in gs], dtype=int)
        vge = np.array(
            [min(x, y) for x, y in zip(ge, mng.data_shape, strict=False)], dtype=int
        )

        if mng.tiling_mode == TilingMode.ShiftBoundary:
            # When a patch starts (ends) exactly at the data edge, extend the
            # valid region down (up) to that edge.
            for dim in range(len(vgs)):
                if ps[dim] == 0:
                    vgs[dim] = 0
                if pe[dim] == mng.data_shape[dim]:
                    vge[dim] = mng.data_shape[dim]

        # relative start, relative end: the valid region in tile coordinates.
        # (named rel_* rather than `re` to avoid shadowing the stdlib module)
        rel_start = vgs - ps
        rel_end = rel_start + (vge - vgs)

        tile = predictions[dset_idx]
        for ch_idx in range(predictions.shape[1]):
            if n_dims == 4:
                # channel dimension is the last one.
                output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
                    tile[ch_idx, rel_start[1] : rel_end[1], rel_start[2] : rel_end[2]]
                )
            elif n_dims == 5:
                # channel dimension is the last one.
                # Explicit check (not `assert`) so it is not stripped under -O.
                if vge[0] - vgs[0] != 1:
                    raise ValueError("Only one frame is supported")
                output[
                    vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
                ] = tile[
                    ch_idx,
                    rel_start[1] : rel_end[1],
                    rel_start[2] : rel_end[2],
                    rel_start[3] : rel_end[3],
                ]
            else:
                raise ValueError(f"Unsupported shape {output.shape}")

    return output
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# TODO: why not allow input and output of torch.tensor ?
def stitch_prediction(
    tiles: list[np.ndarray],
    tile_infos: list[TileInformation],
) -> list[np.ndarray]:
    """
    Stitch tiles back together to form a full image(s).

    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be
    a singleton dimension. Tiles of consecutive images are concatenated in the
    input lists; each image's run of tiles ends with a tile whose `last_tile`
    flag is set. Tiles after the final flagged tile are not stitched.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates. Can contain
        tiles from multiple images.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    list of numpy.ndarray
        Full image(s), one per flagged last tile.
    """
    # Exclusive end index of each image's run: one past every `last_tile` tile.
    image_ends = np.flatnonzero([info.last_tile for info in tile_infos]) + 1

    image_predictions = []
    start = 0
    for end in image_ends:
        image_predictions.append(
            stitch_prediction_single(tiles[start:end], tile_infos[start:end])
        )
        start = end
    return image_predictions
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def stitch_prediction_single(
    tiles: list[NDArray],
    tile_infos: list[TileInformation],
) -> NDArray:
    """
    Stitch the tiles of a single image back into the full image.

    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be
    a singleton dimension. The output uses the channel count of the tiles, which
    may differ from the channel count recorded in the tile information.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    numpy.ndarray
        Full image, with dimensions SC(Z)YX and dtype float32.
    """
    array_shape = tile_infos[0].array_shape
    n_dims = len(array_shape)

    # TODO: this is hacky... need a better way to deal with when input channels and
    # target channels do not match
    if n_dims not in (3, 4):
        # Note pretty sure this is unreachable because array shape is already
        # validated by TileInformation
        raise ValueError(f"Unsupported number of output dimension {n_dims}")

    # array_shape is C(Z)YX, so the tile's channel axis sits n_dims positions from
    # the end: -3 for 2 spatial dimensions, -4 for 3 spatial dimensions.
    tile_channels = tiles[0].shape[-n_dims]

    # Whole image: singleton S dim, the tiles' channel count, then spatial dims.
    predicted_image = np.zeros((1, tile_channels, *array_shape[1:]), dtype=np.float32)

    for tile, info in zip(tiles, tile_infos, strict=False):
        # Drop the overlapping border of the predicted tile...
        crop: tuple[Union[builtins.ellipsis, slice], ...] = (
            ...,
            *(slice(coord[0], coord[1]) for coord in info.overlap_crop_coords),
        )
        # ...and paste what remains at its position in the full image.
        paste = (..., *(slice(coord[0], coord[1]) for coord in info.stitch_coords))

        # TODO fix mypy error here, potentially due to numpy 2
        predicted_image[paste] = tile[crop].astype(np.float32)  # type: ignore

    return predicted_image
|
careamics/py.typed
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Transforms that are used to augment the data."""
|
|
2
|
+
|
|
3
|
+
# Public API of the transforms subpackage: the transform classes plus the
# registry accessor, listed in alphabetical order.
__all__ = [
    "Compose",
    "Denormalize",
    "ImageRestorationTTA",
    "N2VManipulate",
    "N2VManipulateTorch",
    "Normalize",
    "TrainDenormalize",
    "XYFlip",
    "XYRandomRotate90",
    "get_all_transforms",
]
|
|
15
|
+
|
|
16
|
+
from .compose import Compose, get_all_transforms
|
|
17
|
+
from .n2v_manipulate import N2VManipulate
|
|
18
|
+
from .n2v_manipulate_torch import N2VManipulateTorch
|
|
19
|
+
from .normalize import Denormalize, Normalize, TrainDenormalize
|
|
20
|
+
from .tta import ImageRestorationTTA
|
|
21
|
+
from .xy_flip import XYFlip
|
|
22
|
+
from .xy_random_rotate90 import XYRandomRotate90
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""A class chaining transforms together."""
|
|
2
|
+
|
|
3
|
+
from typing import Union, cast
|
|
4
|
+
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
|
|
7
|
+
from careamics.config.transformations import NORM_AND_SPATIAL_UNION
|
|
8
|
+
|
|
9
|
+
from .normalize import Normalize
|
|
10
|
+
from .transform import Transform
|
|
11
|
+
from .xy_flip import XYFlip
|
|
12
|
+
from .xy_random_rotate90 import XYRandomRotate90
|
|
13
|
+
|
|
14
|
+
# Registry of the transforms that `Compose` can instantiate from a
# configuration, keyed by transform name.
ALL_TRANSFORMS: dict[str, type] = {
    "Normalize": Normalize,
    "XYFlip": XYFlip,
    "XYRandomRotate90": XYRandomRotate90,
}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_all_transforms() -> dict[str, type]:
    """Look up the registry of transforms supported by CAREamics.

    Returns
    -------
    dict
        Mapping from each transform name to the corresponding transform class.
        This is the module-level registry object itself, not a copy.
    """
    registry = ALL_TRANSFORMS
    return registry
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Compose:
    """A class chaining transforms together.

    Parameters
    ----------
    transform_list : list[NORM_AND_SPATIAL_UNION]
        A list of transform configurations. Each configuration exposes the
        `name` of the transform and its parameters.

    Attributes
    ----------
    transforms : list[Transform]
        The instantiated transforms, applied in the order of `transform_list`.
    """

    def __init__(self, transform_list: list[NORM_AND_SPATIAL_UNION]) -> None:
        """Instantiate a Compose object.

        Parameters
        ----------
        transform_list : list[NORM_AND_SPATIAL_UNION]
            A list of transform configurations. Each configuration exposes the
            `name` of the transform and its parameters.
        """
        # retrieve all available transforms
        # TODO: correctly type hint get_all_transforms function output
        all_transforms: dict[str, type[Transform]] = get_all_transforms()

        # instantiate all transforms: the configuration's `name` selects the
        # class, and its `model_dump()` provides the constructor kwargs
        self.transforms: list[Transform] = [
            all_transforms[t.name](**t.model_dump()) for t in transform_list
        ]

    def _chain_transforms(
        self, patch: NDArray, target: NDArray | None
    ) -> tuple[NDArray | None, ...]:
        """Chain transforms on the input data.

        Parameters
        ----------
        patch : numpy.ndarray
            Input data.
        target : numpy.ndarray or None
            Target data, or None if there is no target.

        Returns
        -------
        tuple of numpy.ndarray
            The transformed arrays, in input order, with any `None` entries
            (e.g. a missing target) removed.
        """
        # `params` is rebound to a list by the star-unpacking below, so it is
        # annotated as a list rather than a fixed-length tuple
        params: list[Union[NDArray, None]] = [patch, target]

        for t in self.transforms:
            # keep every output except the last one (additional_arrays dict)
            *params, _ = t(*params)

        # avoid None values that create problems for collating
        # TODO: removing None should be handled in dataset, not here
        return tuple(p for p in params if p is not None)

    def _chain_transforms_additional_arrays(
        self,
        patch: NDArray,
        target: NDArray | None,
        **additional_arrays: NDArray,
    ) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
        """Chain transforms on the input data, with additional arrays.

        Parameters
        ----------
        patch : numpy.ndarray
            Input data.
        target : numpy.ndarray or None
            Target data, or None if there is no target.
        **additional_arrays : NDArray
            Additional arrays that will be transformed identically to `patch` and
            `target`.

        Returns
        -------
        numpy.ndarray
            The transformed patch.
        numpy.ndarray or None
            The transformed target.
        dict of {str: numpy.ndarray}
            The transformed additional arrays, keyed by keyword argument name.
        """
        params = {"patch": patch, "target": target, **additional_arrays}

        for t in self.transforms:
            patch, target, additional_arrays = t(**params)
            # feed every output of this transform into the next one
            params = {"patch": patch, "target": target, **additional_arrays}

        return patch, target, additional_arrays

    def __call__(
        self, patch: NDArray, target: NDArray | None = None
    ) -> tuple[NDArray, ...]:
        """Apply the transforms to the input data.

        Parameters
        ----------
        patch : numpy.ndarray
            The input data.
        target : numpy.ndarray or None, optional
            Target data, by default None.

        Returns
        -------
        tuple of numpy.ndarray
            The output of the transformations; a `None` target is omitted.
        """
        # TODO: solve casting Compose.__call__ output
        return cast(tuple[NDArray, ...], self._chain_transforms(patch, target))

    def transform_with_additional_arrays(
        self,
        patch: NDArray,
        target: NDArray | None = None,
        **additional_arrays: NDArray,
    ) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
        """Apply the transforms to the input data, including additional arrays.

        Parameters
        ----------
        patch : numpy.ndarray
            The input data.
        target : numpy.ndarray or None, optional
            Target data, by default None.
        **additional_arrays : NDArray
            Additional arrays that will be transformed identically to `patch` and
            `target`.

        Returns
        -------
        NDArray
            The transformed patch.
        NDArray | None
            The transformed target.
        dict of {str, NDArray}
            Transformed additional arrays. Keys correspond to the keyword argument
            names.
        """
        return self._chain_transforms_additional_arrays(
            patch, target, **additional_arrays
        )
|