careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,404 @@
"""
Script for utility functions needed by the LVAE model.
"""

from typing import Literal, Sequence

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torch.distributions.normal import Normal


def torch_nanmean(inp):
    """Mean of a tensor, ignoring NaN entries."""
    return torch.mean(inp[~inp.isnan()])


def power_of_2(x: int) -> bool:
    """Recursively check whether the integer `x` is a power of 2."""
    assert isinstance(x, int)
    if x == 1:
        return True
    if x == 0:
        # happens with validation
        return False
    if x % 2 == 1:
        return False
    return power_of_2(x // 2)


class Enum:
    """Minimal enum-like base class whose members are plain class attributes."""

    @classmethod
    def name(cls, enum_type):
        for key, value in cls.__dict__.items():
            if enum_type == value:
                return key

    @classmethod
    def contains(cls, enum_type):
        for key, value in cls.__dict__.items():
            if enum_type == value:
                return True
        return False

    @classmethod
    def from_name(cls, enum_type_str):
        for key, value in cls.__dict__.items():
            if key == enum_type_str:
                return value
        raise ValueError(f"{cls.__name__}: {enum_type_str} does not exist.")


class LossType(Enum):
    Elbo = 0
    ElboWithCritic = 1
    ElboMixedReconstruction = 2
    MSE = 3
    ElboWithNbrConsistency = 4
    ElboSemiSupMixedReconstruction = 5
    ElboCL = 6
    ElboRestrictedReconstruction = 7
    DenoiSplitMuSplit = 8


class ModelType(Enum):
    LadderVae = 3
    LadderVaeTwinDecoder = 4
    LadderVAECritic = 5
    # Separate vampprior: two optimizers.
    LadderVaeSepVampprior = 6
    # One encoder for the mixed input, two for the separate inputs.
    LadderVaeSepEncoder = 7
    LadderVAEMultiTarget = 8
    LadderVaeSepEncoderSingleOptim = 9
    UNet = 10
    BraveNet = 11
    LadderVaeStitch = 12
    LadderVaeSemiSupervised = 13
    # Note that previously trained models will have issues, since earlier
    # LadderVaeStitch2Stage = 13 and LadderVaeSemiSupervised = 14.
    LadderVaeStitch2Stage = 14
    LadderVaeMixedRecons = 15
    LadderVaeCL = 16
    # On one sub-dataset apply disentanglement, on the other apply reconstruction.
    LadderVaeTwoDataSet = 17
    LadderVaeTwoDatasetMultiBranch = 18
    LadderVaeTwoDatasetMultiOptim = 19
    LVaeDeepEncoderIntensityAug = 20
    AutoRegresiveLadderVAE = 21
    LadderVAEInterleavedOptimization = 22
    Denoiser = 23
    DenoiserSplitter = 24
    SplitterDenoiser = 25
    LadderVAERestrictedReconstruction = 26
    LadderVAETwoDataSetRestRecon = 27
    LadderVAETwoDataSetFinetuning = 28


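def _demo_enum_helpers() -> None:
    # Minimal usage sketch (hypothetical helper, illustration only) for the
    # Enum lookup methods defined above.
    assert LossType.contains(LossType.MSE)
    assert LossType.name(LossType.MSE) == "MSE"
    assert LossType.from_name("Elbo") == LossType.Elbo

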
def _pad_crop_img(
    x: torch.Tensor, size: Sequence[int], mode: Literal["crop", "pad"]
) -> torch.Tensor:
    """Pad or crop a tensor.

    Pads or crops a tensor of shape (B, C, [Z], Y, X) to a new spatial shape.

    Parameters
    ----------
    x : torch.Tensor
        Input image of shape (B, C, [Z], Y, X).
    size : Sequence[int]
        Desired spatial size ([Z*], Y*, X*).
    mode : Literal["crop", "pad"]
        Mode, either 'pad' or 'crop'.

    Returns
    -------
    torch.Tensor
        The padded or cropped tensor.
    """
    # TODO: Support cropping/padding on selected dimensions
    assert (x.dim() == 4 and len(size) == 2) or (x.dim() == 5 and len(size) == 3)

    size = tuple(size)
    x_size = x.size()[2:]

    if mode == "pad":
        cond = any(x_size[i] > size[i] for i in range(len(size)))
    elif mode == "crop":
        cond = any(x_size[i] < size[i] for i in range(len(size)))
    else:
        raise ValueError(f"Unknown mode: {mode}")

    if cond:
        raise ValueError(f"Trying to {mode} from size {x_size} to size {size}")

    diffs = [abs(dim_size - s) for dim_size, s in zip(x_size, size)]
    d1 = [d // 2 for d in diffs]
    d2 = [d - (d // 2) for d in diffs]

    if mode == "pad":
        if x.dim() == 4:
            padding = [d1[1], d2[1], d1[0], d2[0], 0, 0, 0, 0]
        elif x.dim() == 5:
            padding = [d1[2], d2[2], d1[1], d2[1], d1[0], d2[0], 0, 0, 0, 0]
        return nn.functional.pad(x, padding)
    elif mode == "crop":
        if x.dim() == 4:
            return x[:, :, d1[0] : (x_size[0] - d2[0]), d1[1] : (x_size[1] - d2[1])]
        elif x.dim() == 5:
            return x[
                :,
                :,
                d1[0] : (x_size[0] - d2[0]),
                d1[1] : (x_size[1] - d2[1]),
                d1[2] : (x_size[2] - d2[2]),
            ]


def pad_img_tensor(x: torch.Tensor, size: Sequence[int]) -> torch.Tensor:
    """Pad a tensor.

    Pads a tensor of shape (B, C, [Z], Y, X) to the desired spatial dimensions.

    Parameters
    ----------
    x : torch.Tensor
        Input image of shape (B, C, [Z], Y, X).
    size : Sequence[int]
        Desired spatial size ([Z*], Y*, X*).

    Returns
    -------
    torch.Tensor
        The padded tensor.
    """
    return _pad_crop_img(x, size, "pad")


def crop_img_tensor(x, size) -> torch.Tensor:
    """Crop a tensor.

    Crops a tensor of shape (batch, channels, h, w) to the desired height and
    width given by a tuple.

    Parameters
    ----------
    x : torch.Tensor
        Input image.
    size : Sequence[int]
        Desired size (height, width).

    Returns
    -------
    torch.Tensor
        The cropped tensor.
    """
    return _pad_crop_img(x, size, "crop")


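def _demo_pad_crop() -> None:
    # Minimal usage sketch (hypothetical helper, illustration only): round-trip
    # a (B, C, Y, X) tensor through pad_img_tensor and crop_img_tensor.
    x = torch.zeros(1, 1, 8, 8)
    padded = pad_img_tensor(x, (12, 12))
    assert padded.shape == (1, 1, 12, 12)
    cropped = crop_img_tensor(padded, (8, 8))
    assert torch.equal(cropped, x)

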
class StableExponential:
    """
    Class that redefines exp() to increase numerical stability.

    Naturally, the definition of log() must change accordingly. The two
    definitions remain consistent as an exp/log pair: log() returns the exact
    natural logarithm of exp(), so torch.exp(log(x)) == exp(x) and
    torch.log(exp(x)) == log(x) always hold.

    Definition:
        exp(x) = {
            exp(x)  if x <= 0
            x + 1   if x > 0
        }

        log(x) = {
            x           if x <= 0
            log(1 + x)  if x > 0
        }

    NOTE 1:
        Within the class, everything operates on the tensor given to the
        constructor: exp() computes the stable exponential of that tensor, and
        log() computes the natural logarithm of that same stable exponential.

    NOTE 2:
        Given the output of exp(), torch.log() and the log() method of the
        class give identical results.
    """

    def __init__(self, tensor):
        self._raw_tensor = tensor
        posneg_dic = self.posneg_separation(self._raw_tensor)
        self.pos_f, self.neg_f = posneg_dic["filter"]
        self.pos_data, self.neg_data = posneg_dic["value"]

    def posneg_separation(self, tensor):
        # boolean masks and clipped values for the positive and negative parts
        pos = tensor > 0
        pos_tensor = torch.clip(tensor, min=0)

        neg = tensor <= 0
        neg_tensor = torch.clip(tensor, max=0)

        return {"filter": [pos, neg], "value": [pos_tensor, neg_tensor]}

    def exp(self):
        return torch.exp(self.neg_data) * self.neg_f + (1 + self.pos_data) * self.pos_f

    def log(self):
        return self.neg_data * self.neg_f + torch.log(1 + self.pos_data) * self.pos_f


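def _demo_stable_exponential() -> None:
    # Minimal numerical sketch (hypothetical helper, illustration only): the
    # stable exp()/log() pair is self-consistent, as stated in the docstring.
    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
    stable = StableExponential(x)
    assert torch.allclose(torch.log(stable.exp()), stable.log())
    assert torch.allclose(torch.exp(stable.log()), stable.exp())

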
class StableLogVar:
    """
    Class that provides a numerically stable implementation of log-variance.

    Specifically, it uses the exp() and log() formulas defined in the
    `StableExponential` class.
    """

    def __init__(
        self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
    ):
        """
        Constructor.

        Parameters
        ----------
        logvar : torch.Tensor
            The input (true) logvar vector, to be converted to the stable version.
        enable_stable : bool, optional
            Whether to compute the stable version of log-variance. Default is `True`.
        var_eps : float, optional
            The minimum value attainable by the variance. Default is `1e-6`.
        """
        self._lv = logvar
        self._enable_stable = enable_stable
        self._eps = var_eps

    def get(self) -> torch.Tensor:
        if self._enable_stable is False:
            return self._lv

        return torch.log(self.get_var())

    def get_var(self) -> torch.Tensor:
        """Get the variance from the log-variance."""
        if self._enable_stable is False:
            return torch.exp(self._lv)
        return StableExponential(self._lv).exp() + self._eps

    def get_std(self) -> torch.Tensor:
        return torch.sqrt(self.get_var())

    @property
    def is_3D(self) -> bool:
        """Check if the _lv tensor is 3D.

        Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).
        """
        return self._lv.dim() == 5

    def centercrop_to_size(self, size: int) -> None:
        """
        Center-crop the log-variance tensor to the desired spatial size.

        Implemented only for 2D tensors.

        Parameters
        ----------
        size : int
            The desired spatial extent of the log-variance tensor.
        """
        assert not self.is_3D, "Centercrop is implemented only for 2D tensors."

        if self._lv.shape[-1] == size:
            return

        diff = self._lv.shape[-1] - size
        assert diff > 0 and diff % 2 == 0
        self._lv = F.center_crop(self._lv, (size, size))


class StableMean:
    """Thin wrapper around a mean tensor, mirroring the `StableLogVar` API."""

    def __init__(self, mean):
        self._mean = mean

    def get(self) -> torch.Tensor:
        return self._mean

    @property
    def is_3D(self) -> bool:
        """Check if the _mean tensor is 3D.

        Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).
        """
        return self._mean.dim() == 5

    def centercrop_to_size(self, size: int) -> None:
        """Center-crop the mean tensor to the desired spatial size.

        Implemented only for 2D tensors.

        Parameters
        ----------
        size : int
            The desired spatial extent of the mean tensor.
        """
        assert not self.is_3D, "Centercrop is implemented only for 2D tensors."

        if self._mean.shape[-1] == size:
            return

        diff = self._mean.shape[-1] - size
        assert diff > 0 and diff % 2 == 0
        self._mean = F.center_crop(self._mean, (size, size))


def allow_numpy(func):
    """
    Decorator that converts positional numpy array arguments to torch tensors.

    Keyword arguments are passed through as-is; positional arguments are
    checked and, if they are numpy arrays, converted to torch tensors.
    """

    def numpy_wrapper(*args, **kwargs):
        new_args = []
        for arg in args:
            if isinstance(arg, np.ndarray):
                arg = torch.Tensor(arg)
            new_args.append(arg)
        new_args = tuple(new_args)

        output = func(*new_args, **kwargs)
        return output

    return numpy_wrapper


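def _demo_allow_numpy() -> None:
    # Minimal usage sketch (hypothetical helper, illustration only): a numpy
    # array passed positionally arrives in the wrapped function as a tensor.
    @allow_numpy
    def total(t):
        return t.sum()

    result = total(np.ones((2, 2)))
    assert isinstance(result, torch.Tensor)
    assert result.item() == 4.0

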
class Interpolate(nn.Module):
    """Wrapper for torch.nn.functional.interpolate."""

    def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False):
        super().__init__()
        # exactly one of `size` and `scale` must be given
        assert (size is None) == (scale is not None)
        self.size = size
        self.scale = scale
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # use torch.nn.functional here: `F` is torchvision's functional API,
        # which has no `interpolate`
        out = nn.functional.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return out


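def _demo_interpolate() -> None:
    # Minimal usage sketch (hypothetical helper, illustration only): double
    # the spatial size of a (B, C, Y, X) tensor.
    up = Interpolate(scale=2, mode="bilinear", align_corners=False)
    y = up(torch.rand(1, 3, 8, 8))
    assert y.shape == (1, 3, 16, 16)

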
def kl_normal_mc(z, p_mulv, q_mulv):
    """
    One-sample estimation of the element-wise KL divergence between two
    diagonal multivariate normal distributions. Any number of dimensions;
    broadcasting is supported (be careful).

    :param z: sample at which the log-densities are evaluated.
    :param p_mulv: (mean, log-variance) tuple of the first distribution, as (StableMean, StableLogVar).
    :param q_mulv: (mean, log-variance) tuple of the second distribution, as (StableMean, StableLogVar).
    :return: element-wise estimate log q(z) - log p(z).
    """
    assert isinstance(p_mulv, tuple)
    assert isinstance(q_mulv, tuple)
    p_mu, p_lv = p_mulv
    q_mu, q_lv = q_mulv

    p_std = p_lv.get_std()
    q_std = q_lv.get_std()

    p_distrib = Normal(p_mu.get(), p_std)
    q_distrib = Normal(q_mu.get(), q_std)
    return q_distrib.log_prob(z) - p_distrib.log_prob(z)
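

def _demo_kl_normal_mc() -> None:
    # Minimal usage sketch (hypothetical helper, illustration only): a
    # one-sample KL estimate between two diagonal Gaussians built from the
    # StableMean / StableLogVar wrappers defined above.
    z = torch.zeros(2, 3)
    p = (StableMean(torch.zeros(2, 3)), StableLogVar(torch.zeros(2, 3)))
    q = (StableMean(torch.ones(2, 3)), StableLogVar(torch.zeros(2, 3)))
    kl = kl_normal_mc(z, p, q)
    assert kl.shape == (2, 3)
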
@@ -0,0 +1,54 @@
"""Model creation factory functions."""

from __future__ import annotations

from typing import TYPE_CHECKING, Union

import torch

from careamics.config.support import SupportedArchitecture
from careamics.models.lvae import LadderVAE as LVAE
from careamics.models.unet import UNet
from careamics.utils import get_logger

if TYPE_CHECKING:
    from careamics.config.architectures import (
        LVAEConfig,
        UNetConfig,
    )


logger = get_logger(__name__)


def model_factory(
    model_configuration: Union[UNetConfig, LVAEConfig],
) -> torch.nn.Module:
    """
    Deep learning model factory.

    Supported models are defined in careamics.config.SupportedArchitecture.

    Parameters
    ----------
    model_configuration : Union[UNetConfig, LVAEConfig]
        Model configuration.

    Returns
    -------
    torch.nn.Module
        Model instance.

    Raises
    ------
    NotImplementedError
        If the requested architecture is not implemented.
    """
    if model_configuration.architecture == SupportedArchitecture.UNET:
        return UNet(**model_configuration.model_dump())
    elif model_configuration.architecture == SupportedArchitecture.LVAE:
        return LVAE(**model_configuration.model_dump())
    else:
        raise NotImplementedError(
            f"Model {model_configuration.architecture} is not implemented or unknown."
        )
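
A minimal usage sketch: the factory dispatches on `config.architecture` and
unpacks the validated configuration into the model constructor. `UNetConfig`'s
constructor arguments are not shown in this diff, so the default construction
below is an assumption.

from careamics.config.architectures import UNetConfig
from careamics.models.model_factory import model_factory

config = UNetConfig()  # assumption: fields beyond the defaults may be required
model = model_factory(config)  # dispatches to UNet(**config.model_dump())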