careamics-0.0.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/prediction_utils/lvae_prediction.py
@@ -0,0 +1,158 @@
"""Module containing pytorch implementations for obtaining predictions from an LVAE."""

from typing import Any

import torch

from careamics.models.lvae import LadderVAE as LVAE
from careamics.models.lvae.likelihoods import LikelihoodModule

# TODO: convert these functions to lightning module `predict_step`
# -> mmse_count will have to be an instance attribute?


# This function is needed because the output of the datasets (input here) can include
# auxiliary items, such as the TileInformation. This function allows for easier reuse
# between lvae_predict_single_sample and lvae_predict_mmse.
def lvae_predict_single_sample(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Generate a single sample prediction from an LVAE model, for a given input.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.tensor
        Input to generate prediction for. Expected shape is (S, C, Y, X).

    Returns
    -------
    tuple of (torch.tensor, optional torch.tensor)
        The first element is the sample prediction, and the second element is the
        log-variance. The log-variance will be None if `model.predict_logvar is None`.
    """
    model.eval()  # Not in original predict code: affects batch_norm and dropout layers
    with torch.no_grad():
        output: torch.Tensor
        output, _ = model(input)  # 2nd item is top-down data dict

        # presently, get_mean_lv just splits the output in 2 if predict_logvar=True,
        # optionally clips the logvar if logvar_lowerbound is not None
        # TODO: consider refactoring to remove use of the likelihood object
        sample_prediction, log_var = likelihood_obj.get_mean_lv(output)

    # TODO: output denormalization using target stats that will be saved in data config
    # -> Don't think we need this, saw it in a random bit of code somewhere.

    return sample_prediction, log_var


def lvae_predict_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: tuple[Any],
) -> tuple[tuple[Any], tuple[Any] | None]:
    # TODO: fix docstring return types, ... too many output options
    """
    Generate a single sample prediction from an LVAE model, for a given tiled batch.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.tensor | tuple of (torch.tensor, Any, ...)
        Input to generate prediction for. This can include auxiliary inputs such as
        `TileInformation`, but the model input is always the first item of the tuple.
        Expected shape of the model input is (S, C, Y, X).

    Returns
    -------
    tuple of ((torch.tensor, Any, ...), optional tuple of (torch.tensor, Any, ...))
        The first element is the sample prediction, and the second element is the
        log-variance. The log-variance will be None if `model.predict_logvar is None`.
        Any auxiliary data included in the input will also be included with both the
        sample prediction and the log-variance.
    """
    x: torch.Tensor
    aux: list[Any]
    x, *aux = input

    sample_prediction, log_var = lvae_predict_single_sample(
        model=model, likelihood_obj=likelihood_obj, input=x
    )

    log_var_output = (log_var, *aux) if log_var is not None else None
    return (sample_prediction, *aux), log_var_output


def lvae_predict_mmse_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: tuple[Any],
    mmse_count: int,
) -> tuple[tuple[Any], tuple[Any], tuple[Any] | None]:
    # TODO: fix docstring return types, ... hard to make readable
    """
    Generate the MMSE (minimum mean squared error) prediction, for a given input.

    This is calculated from the mean of multiple single sample predictions.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.tensor | tuple of (torch.tensor, Any, ...)
        Input to generate prediction for. This can include auxiliary inputs such as
        `TileInformation`, but the model input is always the first item of the tuple.
        Expected shape of the model input is (S, C, Y, X).
    mmse_count : int
        Number of samples to generate to calculate MMSE (minimum mean squared error).

    Returns
    -------
    tuple of (tuple of (torch.Tensor[Any], Any, ...))
        A tuple of 3 elements. The first element contains the MMSE prediction, the
        second contains the standard deviation of the samples used to create the MMSE
        prediction. Finally, the last element contains the log-variance of the
        likelihood; this will be `None` if `likelihood.predict_logvar` is `None`.
        Any auxiliary data included in the input will also be included with the
        MMSE prediction, the standard deviation, and the log-variance.
    """
    if mmse_count <= 0:
        raise ValueError("MMSE count must be greater than zero.")

    x: torch.Tensor
    aux: list[Any]
    x, *aux = input

    input_shape = x.shape
    output_shape = (input_shape[0], model.target_ch, *input_shape[2:])
    log_var: torch.Tensor | None = None
    # pre-declare empty array to fill with individual sample predictions
    sample_predictions = torch.zeros(size=(mmse_count, *output_shape))
    for mmse_idx in range(mmse_count):
        sample_prediction, lv = lvae_predict_single_sample(
            model=model, likelihood_obj=likelihood_obj, input=x
        )
        # only keep the log variance of the first sample prediction
        if mmse_idx == 0:
            log_var = lv

        # store sample predictions
        sample_predictions[mmse_idx, ...] = sample_prediction

    mmse_prediction = torch.mean(sample_predictions, dim=0)
    mmse_prediction_std = torch.std(sample_predictions, dim=0)

    log_var_output = (log_var, *aux) if log_var is not None else None
    return (mmse_prediction, *aux), (mmse_prediction_std, *aux), log_var_output
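For orientation, the sketch below shows one way these helpers could be called on a single batch coming out of a tiled prediction dataset. It is a minimal, hedged example and not code from the package: `model`, `likelihood_obj` and `tiled_batch` are assumed to already exist (a trained `LadderVAE`, a matching `LikelihoodModule`, and a `(tensor, *aux)` tuple as described in the docstrings above), and the wrapper name `predict_batch` is hypothetical.

```python
# Minimal usage sketch (not part of the package). Assumes `model` is a trained
# LadderVAE, `likelihood_obj` a matching LikelihoodModule, and `tiled_batch` a
# (tensor of shape (S, C, Y, X), *aux) tuple as produced by the tiled datasets.
from careamics.prediction_utils.lvae_prediction import lvae_predict_mmse_tiled_batch


def predict_batch(model, likelihood_obj, tiled_batch, mmse_count: int = 10):
    """Run an MMSE prediction on one tiled batch and unpack the outputs."""
    (mmse, *aux), (mmse_std, *_), log_var_output = lvae_predict_mmse_tiled_batch(
        model=model,
        likelihood_obj=likelihood_obj,
        input=tiled_batch,
        mmse_count=mmse_count,
    )
    # the log-variance tuple is None when the model does not predict a log-variance
    log_var = log_var_output[0] if log_var_output is not None else None
    return mmse, mmse_std, log_var, aux
```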
careamics/prediction_utils/lvae_tiling_manager.py
@@ -0,0 +1,362 @@
"""Module containing tiling manager class."""

# # TODO: remove this file, left as a reference for now.

# from typing import Any, Optional

# import numpy as np
# from numpy.typing import NDArray

# from careamics.config.tile_information import TileInformation
# from careamics.config.validators import check_axes_validity


# def calculate_padding(
#     patch_start_location: NDArray,
#     patch_size: NDArray,
#     data_shape: NDArray,
# ) -> NDArray:
#     patch_end_location = patch_start_location + patch_size

#     pad_before = np.zeros_like(patch_start_location)
#     start_out_of_bounds = patch_start_location < 0
#     pad_before[start_out_of_bounds] = -patch_start_location[start_out_of_bounds]

#     pad_after = np.zeros_like(patch_start_location)
#     end_out_of_bounds = patch_end_location > data_shape
#     pad_after[end_out_of_bounds] = (
#         patch_end_location - data_shape
#     )[end_out_of_bounds]

#     return np.stack([pad_before, pad_after], axis=1)


# def extract_tile(
#     img: np.ndarray,
#     grid_start_loc: tuple[int, ...],
#     patch_size: tuple[int, ...],
#     overlap: tuple[int, ...],
#     padding: bool,
#     padding_kwargs: Optional[dict[str, Any]] = None,
# ) -> NDArray:
#     if padding_kwargs is None:
#         padding_kwargs = {}

#     data_shape = img.shape
#     patch_start_loc = np.array(grid_start_loc) - np.array(overlap) // 2
#     crop_slices = tuple(
#         slice(max(0, start), min(start + size, dim_shape))
#         for start, size, dim_shape in zip(patch_start_loc, patch_size, data_shape)
#     )
#     crop = img[crop_slices]
#     if padding:
#         pad = calculate_padding(
#             patch_start_location=patch_start_loc,
#             patch_size=patch_size,
#             data_shape=data_shape,
#         )
#         crop = np.pad(crop, pad, **padding_kwargs)

#     return crop


# class TilingManager:

#     def __init__(
#         self,
#         data_shape: tuple[int, ...],
#         tile_size: tuple[int, ...],
#         overlaps: tuple[int, ...],
#         trim_boundary: tuple[int, ...],
#     ):
#         # --- validation
#         if len(data_shape) != len(tile_size):
#             raise ValueError(
#                 f"Data shape:{data_shape} and tile size:{tile_size} must have the "
#                 "same dimension"
#             )
#         if len(data_shape) != len(overlaps):
#             raise ValueError(
#                 f"Data shape:{data_shape} and tile overlaps:{overlaps} must have the "
#                 "same dimension"
#             )
#         # overlaps = np.array(tile_size) - np.array(grid_shape)
#         if (np.array(overlaps) < 0).any():
#             raise ValueError(
#                 "Tile overlap must be positive or zero in all dimensions."
#             )
#         if ((np.array(overlaps) % 2) != 0).any():
#             # TODO: currently not required by CAREamics tiling,
#             # -> because floor divide is used.
#             raise ValueError("Tile overlaps must be even.")

#         # initialize attributes
#         self.data_shape = data_shape
#         self.overlaps = overlaps
#         self.grid_shape = tuple(np.array(tile_size) - np.array(overlaps))
#         self.patch_shape = tile_size
#         self.trim_boundary = trim_boundary

#     def compute_tile_info(self, index: int, axes: str):

#         # TODO: better axis validation, data should already be in the form SC(Z)YX

#         # validate axes
#         check_axes_validity(axes)
#         # z will be -1 if not present
#         spatial_axes = [axes.find("Z"), axes.find("Y"), axes.find("X")]

#         # convert to numpy for convenience
#         data_shape = np.array(self.data_shape)
#         patch_shape = np.array(self.patch_shape)

#         # --- calculate stitch coords
#         stitch_coords_start = np.array(self.get_location_from_dataset_idx(index))
#         stitch_coords_end = stitch_coords_start + np.array(self.grid_shape)

#         # --- patch coords
#         patch_coords_start = stitch_coords_start - np.array(self.overlaps) // 2
#         patch_coords_end = patch_coords_start + patch_shape

#         # --- replace out of bounds indices

#         out_of_lower_bound = stitch_coords_start < 0
#         out_of_upper_bound = stitch_coords_end > data_shape

#         stitch_coords_start[out_of_lower_bound] = 0
#         stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound]

#         # --- calculate overlap crop coords
#         overlap_crop_coords_start = stitch_coords_start - patch_coords_start
#         overlap_crop_coords_end = overlap_crop_coords_start + (
#             stitch_coords_end - stitch_coords_start
#         )

#         # --- combine start and end
#         stitch_coords = tuple(
#             (stitch_coords_start[axis], stitch_coords_end[axis])
#             for axis in spatial_axes
#             if axis != -1
#         )
#         overlap_crop_coords = tuple(
#             (overlap_crop_coords_start[axis], overlap_crop_coords_end[axis])
#             for axis in spatial_axes
#             if axis != -1
#         )

#         channel_axis = axes.find("C")
#         array_shape_processed = tuple(
#             data_shape[axis] for axis in [channel_axis, *spatial_axes] if axis != -1
#         )

#         tile_info = TileInformation(
#             array_shape=array_shape_processed,
#             last_tile=index == self.total_grid_count() - 1,
#             overlap_crop_coords=overlap_crop_coords,
#             stitch_coords=stitch_coords,
#             sample_id=0,  # TODO: in iterable dataset this is also always 0 pretty sure
#         )
#         return tile_info

#     def patch_offset(self):
#         return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2

#     def get_individual_dim_grid_count(self, dim: int):
#         """
#         Returns the number of grids in the specified dimension, ignoring all other
#         dimensions.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"

#         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
#             return self.data_shape[dim]
#         elif self.trim_boundary is False:
#             return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
#         else:
#             excess_size = self.patch_shape[dim] - self.grid_shape[dim]
#             return int(
#                 np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
#             )

#     def total_grid_count(self):
#         """
#         Returns the total number of grids in the dataset.
#         """
#         return self.grid_count(0) * self.get_individual_dim_grid_count(0)

#     def grid_count(self, dim: int):
#         """
#         Returns the total number of grids for one value in the specified dimension.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"
#         if dim == len(self.data_shape) - 1:
#             return 1

#         return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)

#     def get_grid_index(self, dim: int, coordinate: int):
#         """
#         Returns the index of the grid in the specified dimension.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"
#         assert (
#             coordinate < self.data_shape[dim]
#         ), (
#             f"Coordinate {coordinate} is out of bounds for data "
#             f"shape {self.data_shape}"
#         )
#         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
#             return coordinate
#         elif self.trim_boundary is False:
#             return np.floor(coordinate / self.grid_shape[dim])
#         else:
#             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
#             # can be <0 if coordinate is in [0,grid_shape[dim]]
#             return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))

#     def dataset_idx_from_grid_idx(self, grid_idx: tuple):
#         """
#         Returns the index of the grid in the dataset.
#         """
#         assert len(grid_idx) == len(
#             self.data_shape
#         ), (
#             f"Dimension indices {grid_idx} must have the same dimension as data "
#             f"shape {self.data_shape}"
#         )
#         index = 0
#         for dim in range(len(grid_idx)):
#             index += grid_idx[dim] * self.grid_count(dim)
#         return index

#     def get_patch_location_from_dataset_idx(self, dataset_idx: int):
#         """
#         Returns the patch location of the grid in the dataset.
#         """
#         location = self.get_location_from_dataset_idx(dataset_idx)
#         offset = self.patch_offset()
#         return tuple(np.array(location) - np.array(offset))

#     def get_dataset_idx_from_grid_location(self, location: tuple):
#         assert len(location) == len(
#             self.data_shape
#         ), (
#             f"Location {location} must have the same dimension as data shape "
#             f"{self.data_shape}"
#         )
#         grid_idx = [
#             self.get_grid_index(dim, location[dim]) for dim in range(len(location))
#         ]
#         return self.dataset_idx_from_grid_idx(tuple(grid_idx))

#     def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
#         """
#         Returns the grid-start coordinate of the grid in the specified dimension.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"
#         assert dim_index < self.get_individual_dim_grid_count(
#             dim
#         ), (
#             f"Dimension index {dim_index} is out of bounds for data shape "
#             f"{self.data_shape}"
#         )

#         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
#             return dim_index
#         elif self.trim_boundary is False:
#             return dim_index * self.grid_shape[dim]
#         else:
#             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
#             return dim_index * self.grid_shape[dim] + excess_size

#     def get_location_from_dataset_idx(self, dataset_idx: int):
#         grid_idx = []
#         for dim in range(len(self.data_shape)):
#             grid_idx.append(dataset_idx // self.grid_count(dim))
#             dataset_idx = dataset_idx % self.grid_count(dim)
#         location = [
#             self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
#             for dim in range(len(self.data_shape))
#         ]
#         return tuple(location)

#     def on_boundary(self, dataset_idx: int, dim: int):
#         """
#         Returns True if the grid is on the boundary in the specified dimension.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"

#         if dim > 0:
#             dataset_idx = dataset_idx % self.grid_count(dim - 1)

#         dim_index = dataset_idx // self.grid_count(dim)
#         return (
#             dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
#         )

#     def next_grid_along_dim(self, dataset_idx: int, dim: int):
#         """
#         Returns the index of the grid in the specified dimension in the specified
#         direction.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"
#         new_idx = dataset_idx + self.grid_count(dim)
#         if new_idx >= self.total_grid_count():
#             return None
#         return new_idx

#     def prev_grid_along_dim(self, dataset_idx: int, dim: int):
#         """
#         Returns the index of the grid in the specified dimension in the specified
#         direction.
#         """
#         assert dim < len(
#             self.data_shape
#         ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
#         assert dim >= 0, "Dimension must be greater than or equal to 0"
#         new_idx = dataset_idx - self.grid_count(dim)
#         if new_idx < 0:
#             return None


# if __name__ == "__main__":
#     data_shape = (1, 1, 103, 103, 2)
#     grid_shape = (1, 1, 16, 16, 2)
#     patch_shape = (1, 1, 32, 32, 2)
#     overlap = tuple(np.array(patch_shape) - np.array(grid_shape))

#     trim_boundary = False
#     manager = TilingManager(
#         data_shape=data_shape,
#         tile_size=patch_shape,
#         overlaps=overlap,
#         trim_boundary=trim_boundary,
#     )
#     gc = manager.total_grid_count()
#     print("Grid count", gc)
#     for i in range(gc):
#         loc = manager.get_location_from_dataset_idx(i)
#         print(i, loc)
#         inferred_i = manager.get_dataset_idx_from_grid_location(loc)
#         assert i == inferred_i, f"Index mismatch: {i} != {inferred_i}"

#     for i in range(5):
#         print(manager.on_boundary(40, i))