careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Loss submodule.
|
|
3
|
+
|
|
4
|
+
This submodule contains the various losses used in CAREamics.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch.nn import L1Loss, MSELoss
|
|
9
|
+
|
|
10
|
+
from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def mse_loss(source: torch.Tensor, target: torch.Tensor, *args) -> torch.Tensor:
    """
    Mean squared error loss.

    Thin wrapper around `torch.nn.MSELoss` (default mean reduction) that
    accepts, and ignores, extra positional arguments so that it can be called
    with the same argument list as the other CAREamics losses.

    Parameters
    ----------
    source : torch.Tensor
        Source patches (e.g. network predictions).
    target : torch.Tensor
        Target patches.
    *args : Any
        Additional arguments, ignored.

    Returns
    -------
    torch.Tensor
        Scalar mean squared error between `source` and `target`.
    """
    return MSELoss()(source, target)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def n2v_loss(
    manipulated_batch: torch.Tensor,
    original_batch: torch.Tensor,
    masks: torch.Tensor,
    *args,
) -> torch.Tensor:
    """
    N2V Loss function described in A Krull et al 2018.

    Computes the squared error between the network output and the original
    images, weights it element-wise by `masks`, and normalises by the sum of
    the mask, i.e. a mean squared error restricted to the masked pixels.

    Parameters
    ----------
    manipulated_batch : torch.Tensor
        Batch after manipulation function applied (network output).
    original_batch : torch.Tensor
        Original images.
    masks : torch.Tensor
        Mask tensor multiplied element-wise with the squared errors; presumably
        1 at manipulated pixels and 0 elsewhere (TODO confirm with caller).
    *args : Any
        Additional arguments, ignored.

    Returns
    -------
    torch.Tensor
        Scalar loss value, averaged over the masked pixels of the whole batch.
    """
    squared_error = (original_batch - manipulated_batch) ** 2
    # Only masked pixels contribute; normalise by the mask total so the result
    # is an average over pixels and batch.
    masked_error_total = (squared_error * masks).sum()
    mask_total = masks.sum()
    return masked_error_total / mask_total  # TODO change output to dict ?
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def pn2v_loss(
    samples: torch.Tensor,
    labels: torch.Tensor,
    masks: torch.Tensor,
    noise_model: GaussianMixtureNoiseModel,
) -> torch.Tensor:
    """
    Probabilistic N2V loss function described in A Krull et al., CVF (2019).

    The noise model's likelihood of the observed `labels` given the predicted
    `samples` is averaged over dimension 1, log-transformed, and the negative
    mask-weighted mean over all pixels is returned.

    Parameters
    ----------
    samples : torch.Tensor
        Predicted pixel values from the network; likelihoods are averaged over
        dimension 1, which presumably indexes the predicted samples (TODO
        confirm). The name is kept for backward compatibility.
    labels : torch.Tensor
        Original pixel values.
    masks : torch.Tensor
        Mask tensor weighting each pixel element-wise; presumably 1 at
        manipulated pixels and 0 elsewhere.
    noise_model : GaussianMixtureNoiseModel
        Noise model for computing likelihood.

    Returns
    -------
    torch.Tensor
        Scalar loss value (negative mean log-likelihood over masked pixels).
    """
    per_sample_likelihood = noise_model.likelihood(labels, samples)
    # Average the likelihood across dimension 1 before taking the log.
    log_mean_likelihood = per_sample_likelihood.mean(dim=1, keepdim=True).log()

    # Average over pixels and batch, restricted to the masked pixels.
    masked_log_likelihood = (log_mean_likelihood * masks).sum()
    return -masked_log_likelihood / masks.sum()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def mae_loss(samples: torch.Tensor, labels: torch.Tensor, *args) -> torch.Tensor:
    """
    N2N Loss function described in J Lehtinen et al 2018.

    Mean absolute error between `samples` and `labels`, computed with
    `torch.nn.L1Loss` (default mean reduction).

    Parameters
    ----------
    samples : torch.Tensor
        Raw patches.
    labels : torch.Tensor
        Different subset of noisy patches.
    *args : Any
        Additional arguments, ignored.

    Returns
    -------
    torch.Tensor
        Scalar mean absolute error.
    """
    return L1Loss()(samples, labels)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# def dice_loss(
|
|
122
|
+
# samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
|
|
123
|
+
# ) -> torch.Tensor:
|
|
124
|
+
# """Dice loss function."""
|
|
125
|
+
# return DiceLoss(mode=mode)(samples, labels.long())
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Loss factory module.
|
|
3
|
+
|
|
4
|
+
This module contains a factory function for creating loss functions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Union
|
|
12
|
+
|
|
13
|
+
from torch import Tensor as tensor
|
|
14
|
+
|
|
15
|
+
from ..config.support import SupportedLoss
|
|
16
|
+
from .fcn.losses import mae_loss, mse_loss, n2v_loss, pn2v_loss
|
|
17
|
+
from .lvae.losses import (
|
|
18
|
+
denoisplit_loss,
|
|
19
|
+
denoisplit_musplit_loss,
|
|
20
|
+
hdn_loss,
|
|
21
|
+
musplit_loss,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class FCNLossParameters:
    """Dataclass for FCN loss.

    Plain container of values associated with a fully convolutional network
    (FCN) loss computation.

    NOTE(review): `loss_factory` in this module does not reference this class;
    confirm whether it is still used elsewhere.
    """

    # TODO check
    # Network output. Shape presumably matches `targets` — TODO confirm.
    prediction: tensor
    # Ground-truth / reference values for the loss.
    targets: tensor
    # Mask tensor; presumably selects the pixels contributing to the loss.
    mask: tensor
    # Current training epoch.
    current_epoch: int
    # Weight applied to the loss; how it is used is not shown here.
    loss_weight: float
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
    """Look up the loss function matching a requested loss.

    The request is compared with `==` against each supported loss in turn,
    so both `SupportedLoss` members and their equivalent string values are
    accepted.

    Parameters
    ----------
    loss : Union[SupportedLoss, str]
        Requested loss.

    Returns
    -------
    Callable
        Loss function.

    Raises
    ------
    NotImplementedError
        If the loss is unknown.
    """
    # An ordered table compared with `==` (not a dict): Enum hashes by member
    # name, so a dict lookup would not match plain string values.
    known_losses = (
        (SupportedLoss.N2V, n2v_loss),
        (SupportedLoss.PN2V, pn2v_loss),
        (SupportedLoss.MAE, mae_loss),
        (SupportedLoss.MSE, mse_loss),
        (SupportedLoss.HDN, hdn_loss),
        (SupportedLoss.MUSPLIT, musplit_loss),
        (SupportedLoss.DENOISPLIT, denoisplit_loss),
        (SupportedLoss.DENOISPLIT_MUSPLIT, denoisplit_musplit_loss),
    )
    for supported, loss_function in known_losses:
        if loss == supported:
            return loss_function

    raise NotImplementedError(f"Loss {loss} is not yet supported.")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""LVAE losses."""
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def free_bits_kl(
    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> torch.Tensor:
    """Compute free-bits version of KL divergence.

    Clamping the KL from below at `free_bits` keeps it from collapsing to
    zero for any latent layer, which encourages the model to actually use
    its latent variables.

    The input has shape (batch size, layers) and the output has shape
    (layers,): the batch-averaged, free-bits KL per layer. With
    `batch_average=False` (default) each batch element is clamped before
    averaging, i.e. mean(clamp(KL)); with `batch_average=True` the batch mean
    is clamped instead, i.e. clamp(mean(KL)).

    Parameters
    ----------
    kl : torch.Tensor
        The KL divergence tensor with shape (batch size, layers).
    free_bits : float
        The free bits value. Values below `eps` disable free bits.
    batch_average : bool
        Whether to average over the batch before clamping to `free_bits`.
    eps : float
        Threshold below which `free_bits` is treated as zero.

    Returns
    -------
    torch.Tensor
        The free-bits version of the KL divergence with shape (layers,).
    """
    assert kl.dim() == 2

    free_bits_disabled = free_bits < eps
    if free_bits_disabled:
        # Plain batch average, no clamping.
        return kl.mean(dim=0)

    if batch_average:
        per_layer_kl = kl.mean(dim=0)
        return per_layer_kl.clamp(min=free_bits)

    clamped_kl = kl.clamp(min=free_bits)
    return clamped_kl.mean(dim=0)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_kl_weight(
    kl_annealing: bool,
    kl_start: int,
    kl_annealtime: int,
    kl_weight: float,
    current_epoch: int,
) -> float:
    """Compute the weight of the KL loss, optionally with linear annealing.

    With annealing enabled, an annealing factor ramps linearly from 0 at
    epoch `kl_start` to 1 at epoch `kl_start + kl_annealtime` and is clamped
    to [0, 1] outside that range. If a final `kl_weight` is given, the factor
    is multiplied by it, so the returned weight ramps from 0 up to
    `kl_weight`.

    Parameters
    ----------
    kl_annealing : bool
        Whether to use KL annealing.
    kl_start : int
        The epoch at which annealing starts (the weight is 0 before it).
    kl_annealtime : int
        The number of epochs over which the weight is ramped up. Must be
        non-zero when `kl_annealing` is True.
    kl_weight : float or None
        The final weight for the KL loss. With annealing, the annealing
        factor is scaled by this value; without annealing it is returned
        as is. If `None`, a final weight of 1 is used.
    current_epoch : int
        The current epoch.

    Returns
    -------
    float
        The KL loss weight for `current_epoch`.

    Raises
    ------
    ZeroDivisionError
        If `kl_annealing` is True and `kl_annealtime` is 0.
    """
    if kl_annealing:
        # Linear ramp, kept separate from `kl_weight` so that the requested
        # final weight scales the ramp instead of being overwritten by it.
        annealing_factor = (current_epoch - kl_start) / kl_annealtime
        annealing_factor = min(max(0.0, annealing_factor), 1.0)

        # If the final weight is given, apply it on top of the ramp.
        if kl_weight is not None:
            return annealing_factor * kl_weight
        return annealing_factor

    if kl_weight is not None:
        return kl_weight
    return 1.0
|