careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,738 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import TYPE_CHECKING, Optional
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from numpy.typing import NDArray
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
13
|
+
|
|
14
|
+
# TODO this module shouldn't be in lvae folder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def create_histogram(
    bins: int, min_val: float, max_val: float, observation: NDArray, signal: NDArray
) -> NDArray:
    """
    Creates a 2D histogram from 'observation' and 'signal'.

    Each row of the count histogram corresponds to one signal bin and is
    normalized to sum to one (rows without samples stay all zeros).

    Parameters
    ----------
    bins : int
        Number of bins in x and y.
    min_val : float
        Lower bound of the lowest bin in x and y.
    max_val : float
        Upper bound of the highest bin in x and y.
    observation : np.ndarray
        3D numpy array (stack of 2D images).
        observation.shape[0] must be divisible by signal.shape[0].
        Assumes that n subsequent images in observation belong to one image in 'signal'.
    signal : np.ndarray
        3D numpy array (stack of 2D images).

    Returns
    -------
    histogram : np.ndarray
        A 3D array of shape (3, bins, bins):
        - histogram[0]: Row-normalized 2D counts (rows index signal bins,
          columns index observation bins).
        - histogram[1]: Lower boundaries of bins along y.
        - histogram[2]: Upper boundaries of bins along y.
        The values for x can be obtained by transposing 'histogram[1]' and 'histogram[2]'.

    Raises
    ------
    ValueError
        If `signal` is empty, or if `observation.shape[0]` is not a multiple of
        `signal.shape[0]`.
    """
    n_observations = observation.shape[0]
    n_signals = signal.shape[0]
    # Validate up front: `int(a / b)` used to become 0 when there were fewer
    # observations than signals, after which numpy's integer `// 0` silently
    # mapped every observation onto signal[0] (only a RuntimeWarning).
    if n_signals == 0 or n_observations % n_signals != 0:
        raise ValueError(
            "observation.shape[0] must be a multiple of signal.shape[0], got "
            f"observation.shape[0]={n_observations}, "
            f"signal.shape[0]={n_signals}."
        )
    obs_per_signal = n_observations // n_signals

    # Each signal image is paired with `obs_per_signal` consecutive observations.
    signal_values = np.repeat(signal, obs_per_signal, axis=0).ravel()
    observation_values = observation.ravel()

    value_range = [min_val, max_val]
    count_histogram, signal_edges, _ = np.histogram2d(
        signal_values, observation_values, bins=bins, range=[value_range, value_range]
    )

    # Normalize rows to obtain probabilities; the clip keeps empty rows at 0
    # instead of dividing by zero.
    row_sums = count_histogram.sum(axis=1, keepdims=True)
    count_histogram /= np.clip(row_sums, a_min=1e-20, a_max=None)

    histogram = np.zeros((3, bins, bins))
    histogram[0] = count_histogram
    histogram[1] = signal_edges[:-1][..., np.newaxis]
    histogram[2] = signal_edges[1:][..., np.newaxis]
    return histogram
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def noise_model_factory(
    model_config: Optional[GaussianMixtureNMConfig],
) -> Optional[GaussianMixtureNoiseModel]:
    """Build a single-channel noise model from its configuration.

    Parameters
    ----------
    model_config : Optional[GaussianMixtureNMConfig]
        Configuration of one Gaussian mixture noise model. A falsy value
        (e.g. None) means "no noise model".

    Returns
    -------
    Optional[GaussianMixtureNoiseModel]
        The noise model loaded from `model_config.path`, or None when no
        configuration is given.

    Raises
    ------
    NotImplementedError
        If `model_config.model_type` is anything other than
        "GaussianMixtureNoiseModel", or if no `path` is set (training a new
        model from this factory is not supported).
    """
    if not model_config:
        return None

    if model_config.path:
        # A pre-trained model on disk: only the GMM type can be loaded.
        if model_config.model_type != "GaussianMixtureNoiseModel":
            raise NotImplementedError(
                f"Model {model_config.model_type} is not implemented"
            )
        return GaussianMixtureNoiseModel(model_config)

    # TODO this is outdated and likely should be removed !!
    # No path means signal/observation were supposed to be used for training
    # (controlled in the pydantic model); training is not wired in here.
    if model_config.model_type == "GaussianMixtureNoiseModel":
        raise NotImplementedError(
            "GaussianMixtureNoiseModel model training is not implemented."
        )
    raise NotImplementedError(f"Model {model_config.model_type} is not implemented")
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def multichannel_noise_model_factory(
|
|
122
|
+
model_config: Optional[MultiChannelNMConfig],
|
|
123
|
+
) -> Optional[MultiChannelNoiseModel]:
|
|
124
|
+
"""Multi-channel noise model factory.
|
|
125
|
+
|
|
126
|
+
Parameters
|
|
127
|
+
----------
|
|
128
|
+
model_config : Optional[MultiChannelNMConfig]
|
|
129
|
+
Noise model configuration, a `MultiChannelNMConfig` config that defines
|
|
130
|
+
noise models for the different output channels.
|
|
131
|
+
|
|
132
|
+
Returns
|
|
133
|
+
-------
|
|
134
|
+
Optional[MultiChannelNoiseModel]
|
|
135
|
+
A noise model instance.
|
|
136
|
+
|
|
137
|
+
Raises
|
|
138
|
+
------
|
|
139
|
+
NotImplementedError
|
|
140
|
+
If the chosen noise model `model_type` is not implemented.
|
|
141
|
+
Currently only `GaussianMixtureNoiseModel` is implemented.
|
|
142
|
+
"""
|
|
143
|
+
if model_config:
|
|
144
|
+
noise_models = []
|
|
145
|
+
for nm in model_config.noise_models:
|
|
146
|
+
if nm.path:
|
|
147
|
+
if nm.model_type == "GaussianMixtureNoiseModel":
|
|
148
|
+
noise_models.append(GaussianMixtureNoiseModel(nm))
|
|
149
|
+
else:
|
|
150
|
+
raise NotImplementedError(
|
|
151
|
+
f"Model {nm.model_type} is not implemented"
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
# TODO this is outdated and likely should be removed !!
|
|
155
|
+
else: # TODO this means signal/obs are provided. Controlled in pydantic model
|
|
156
|
+
# TODO train a new model. Config should always be provided?
|
|
157
|
+
if nm.model_type == "GaussianMixtureNoiseModel":
|
|
158
|
+
# TODO one model for each channel all make this choise inside the model?
|
|
159
|
+
# trained_nm = train_gm_noise_model(nm)
|
|
160
|
+
# noise_models.append(trained_nm)
|
|
161
|
+
raise NotImplementedError(
|
|
162
|
+
"GaussianMixtureNoiseModel model training is not implemented."
|
|
163
|
+
)
|
|
164
|
+
else:
|
|
165
|
+
raise NotImplementedError(
|
|
166
|
+
f"Model {nm.model_type} is not implemented"
|
|
167
|
+
)
|
|
168
|
+
return MultiChannelNoiseModel(noise_models)
|
|
169
|
+
return None
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def train_gm_noise_model(
    model_config: GaussianMixtureNMConfig,
    signal: np.ndarray,
    observation: np.ndarray,
) -> GaussianMixtureNoiseModel:
    """Instantiate a Gaussian mixture noise model and fit it to data.

    Parameters
    ----------
    model_config : GaussianMixtureNMConfig
        Configuration used to construct the (untrained) noise model.
    signal : np.ndarray
        Clean signal passed as the first argument to the model's `fit`.
    observation : np.ndarray
        Noisy observations passed as the second argument to the model's `fit`.

    Returns
    -------
    GaussianMixtureNoiseModel
        The model after `fit(signal, observation)` has been called on it.
    """
    # TODO where to put train params?
    # TODO any training params ? Different channels ?
    # TODO revisit config unpacking
    gmm = GaussianMixtureNoiseModel(model_config)
    gmm.fit(signal, observation)
    return gmm
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class MultiChannelNoiseModel(nn.Module):
    """Container of per-channel noise models (e.g., muSplit, denoiSplit).

    Each non-None model passed to the constructor is registered as a child
    module named `nmodel_{i}`, where `i` is its position in the input list.
    """

    def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):
        """Constructor.

        To handle noise models and the relative likelihood computation for multiple
        output channels (e.g., muSplit, denoiseSplit).

        This class:
        - receives as input a variable number of noise models, one for each channel.
        - computes the likelihood of observations given signals for each channel.
        - returns the concatenation of these likelihoods.

        Parameters
        ----------
        nmodels : list[GaussianMixtureNoiseModel]
            List of noise models, one for each output channel. `None` entries
            are skipped (they leave a gap in the `nmodel_{i}` numbering).
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Register and count in a single pass, keeping the list position in the
        # module name so channel `i` is served by `nmodel_{i}`.
        self._nm_cnt = 0
        for i, nmodel in enumerate(nmodels):
            if nmodel is not None:
                self.add_module(f"nmodel_{i}", nmodel)
                self._nm_cnt += 1

        print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")

    def to_device(self, device: torch.device):
        """Move the container and every registered noise model to `device`.

        Parameters
        ----------
        device : torch.device
            Target device; also stored on `self.device`.
        """
        self.device = device
        self.to(device)
        # Iterate the registered children rather than `range(self._nm_cnt)`:
        # the latter assumes contiguous `nmodel_{i}` names and raised
        # AttributeError when a `None` entry had left a gap.
        for nmodel in self.children():
            nmodel.to_device(device)

    def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
        """Compute the likelihood of observations given signals for each channel.

        Parameters
        ----------
        obs : torch.Tensor
            Noisy observations, i.e., the target(s). Specifically, the input noisy
            image for HDN, or the noisy unmixed images used for supervision
            for denoiSplit. Shape: (B, C, [Z], Y, X), where C is the number of
            unmixed channels.
        signal : torch.Tensor
            Underlying signals, i.e., the (clean) output of the model. Specifically, the
            denoised image for HDN, or the unmixed images for denoiSplit.
            Shape: (B, C, [Z], Y, X), where C is the number of unmixed channels.

        Returns
        -------
        torch.Tensor
            Per-channel likelihoods concatenated along dim 1.
        """
        # Case 1: obs and signal have a single channel (e.g., denoising)
        if obs.shape[1] == 1:
            assert signal.shape[1] == 1
            return self.nmodel_0.likelihood(obs, signal)

        # Case 2: obs and signal have multiple channels (e.g., denoiSplit)
        assert obs.shape[1] == self._nm_cnt, (
            "The number of channels in `obs` must match the number of noise models."
            f" Got instead: obs={obs.shape[1]}, nm={self._nm_cnt}"
        )
        ll_list = []
        for ch_idx in range(obs.shape[1]):
            nmodel = getattr(self, f"nmodel_{ch_idx}")
            ll_list.append(
                nmodel.likelihood(
                    obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1]
                )  # slicing to keep the channel dimension
            )
        return torch.cat(ll_list, dim=1)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class GaussianMixtureNoiseModel(nn.Module):
|
|
273
|
+
"""Define a noise model parameterized as a mixture of gaussians.
|
|
274
|
+
|
|
275
|
+
If `config.path` is not provided a new object is initialized from scratch.
|
|
276
|
+
Otherwise, a model is loaded from `config.path`.
|
|
277
|
+
|
|
278
|
+
Parameters
|
|
279
|
+
----------
|
|
280
|
+
config : GaussianMixtureNMConfig
|
|
281
|
+
A `pydantic` model that defines the configuration of the GMM noise model.
|
|
282
|
+
|
|
283
|
+
Attributes
|
|
284
|
+
----------
|
|
285
|
+
min_signal : float
|
|
286
|
+
Minimum signal intensity expected in the image.
|
|
287
|
+
max_signal : float
|
|
288
|
+
Maximum signal intensity expected in the image.
|
|
289
|
+
path: Union[str, Path]
|
|
290
|
+
Path to the directory where the trained noise model (*.npz) is saved in the `train` method.
|
|
291
|
+
weight : torch.nn.Parameter
|
|
292
|
+
A [3*n_gaussian, n_coeff] sized array containing the values of the weights
|
|
293
|
+
describing the GMM noise model, with each row corresponding to one
|
|
294
|
+
parameter of each gaussian, namely [mean, standard deviation and weight].
|
|
295
|
+
Specifically, rows are organized as follows:
|
|
296
|
+
- first n_gaussian rows correspond to the means
|
|
297
|
+
- next n_gaussian rows correspond to the weights
|
|
298
|
+
- last n_gaussian rows correspond to the standard deviations
|
|
299
|
+
If `weight=None`, the weight array is initialized using the `min_signal`
|
|
300
|
+
and `max_signal` parameters.
|
|
301
|
+
n_gaussian: int
|
|
302
|
+
Number of gaussians in the mixture.
|
|
303
|
+
n_coeff: int
|
|
304
|
+
Number of coefficients to describe the functional relationship between gaussian
|
|
305
|
+
parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
|
|
306
|
+
relationship and so on.
|
|
307
|
+
device: device
|
|
308
|
+
GPU device.
|
|
309
|
+
min_sigma: float
|
|
310
|
+
All values of `standard deviation` below this are clamped to this value.
|
|
311
|
+
"""
|
|
312
|
+
|
|
313
|
+
# TODO training a NM relies on getting a clean data(N2V e.g,)
|
|
314
|
+
def __init__(self, config: GaussianMixtureNMConfig) -> None:
    """Build the noise model from `config`, loading parameters from disk if set.

    Parameters
    ----------
    config : GaussianMixtureNMConfig
        If `config.path` is not None, parameters are read from that archive
        (presumably the `.npz` written after training — TODO confirm).
        Otherwise they come from `config.model_dump(exclude_none=True)`.
    """
    super().__init__()
    self.device = torch.device("cpu")

    if config.path is not None:
        # np.load on an .npz returns a lazily-read NpzFile that keeps the
        # archive open; copy every array out and close it deterministically.
        with np.load(config.path) as npz_file:
            params = {key: npz_file[key] for key in npz_file.files}
    else:
        params = config.model_dump(exclude_none=True)

    min_sigma = torch.tensor(params["min_sigma"])
    min_signal = torch.tensor(params["min_signal"])
    max_signal = torch.tensor(params["max_signal"])
    self.register_buffer("min_signal", min_signal)
    self.register_buffer("max_signal", max_signal)
    self.register_buffer("min_sigma", min_sigma)
    self.register_buffer("tolerance", torch.tensor([1e-10]))

    # Weight priority: trained weights from disk, then explicit config
    # weights, then a fresh random initialization.
    if "trained_weight" in params:
        weight = torch.tensor(params["trained_weight"])
    elif "weight" in params and params["weight"] is not None:
        weight = torch.tensor(params["weight"])
    else:
        weight = self._initialize_weights(
            params["n_gaussian"], params["n_coeff"], max_signal, min_signal
        )

    # The weight matrix is (3 * n_gaussian, n_coeff); derive sizes from it so
    # loaded weights define the model shape.
    self.n_gaussian = weight.shape[0] // 3
    self.n_coeff = weight.shape[1]

    self.register_parameter("weight", nn.Parameter(weight))
    self._set_model_mode(mode="prediction")

    print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")
|
|
347
|
+
|
|
348
|
+
def _initialize_weights(
|
|
349
|
+
self,
|
|
350
|
+
n_gaussian: int,
|
|
351
|
+
n_coeff: int,
|
|
352
|
+
max_signal: torch.Tensor,
|
|
353
|
+
min_signal: torch.Tensor,
|
|
354
|
+
) -> torch.Tensor:
|
|
355
|
+
"""Create random weight initialization."""
|
|
356
|
+
weight = torch.randn(n_gaussian * 3, n_coeff)
|
|
357
|
+
weight[n_gaussian : 2 * n_gaussian, 1] = torch.log(
|
|
358
|
+
max_signal - min_signal
|
|
359
|
+
).float()
|
|
360
|
+
return weight
|
|
361
|
+
|
|
362
|
+
def to_device(self, device: torch.device):
    """Remember `device` on the instance, then move parameters/buffers there.

    Parameters
    ----------
    device : torch.device
        Target device.
    """
    self.device = device
    self.to(device)
    return None
|
|
365
|
+
|
|
366
|
+
def _set_model_mode(self, mode: str) -> None:
|
|
367
|
+
"""Move parameters to the device and set weights' requires_grad depending on the mode"""
|
|
368
|
+
if mode == "train":
|
|
369
|
+
self.weight.requires_grad = True
|
|
370
|
+
else:
|
|
371
|
+
self.weight.requires_grad = False
|
|
372
|
+
|
|
373
|
+
def polynomial_regressor(
    self, weight_params: torch.Tensor, signals: torch.Tensor
) -> torch.Tensor:
    """Combines `weight_params` and signal `signals` to regress for the gaussian parameter values.

    The signals are min-max normalised with ``self.min_signal`` and
    ``self.max_signal``. The polynomial ``sum_i weight_params[i] * x**i`` is
    then evaluated element-wise on the normalised values ``x``.

    Parameters
    ----------
    weight_params : Tensor
        Polynomial coefficients, lowest order first; corresponds to a row of
        `self.weight`.
    signals : Tensor
        Signals at which to evaluate the polynomial.

    Returns
    -------
    value : Tensor
        Tensor shaped like `signals`: either a mean, a variance or a
        mixture-weight logit, evaluated at `signals`.
    """
    value = torch.zeros_like(signals)
    device = value.device
    # TODO the whole device handling in this class needs to be refactored
    weight_params = weight_params.to(device)
    self.min_signal = self.min_signal.to(device)
    self.max_signal = self.max_signal.to(device)

    # The normalisation does not depend on the coefficient index, so compute
    # it once instead of once per polynomial term.
    normalized = (signals - self.min_signal) / (self.max_signal - self.min_signal)
    for i in range(weight_params.shape[0]):
        value += weight_params[i] * (normalized**i)
    return value
|
|
403
|
+
|
|
404
|
+
def normal_density(
    self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
) -> torch.Tensor:
    """
    Evaluate the normal probability density of `x` under N(`mean`, `std`**2).

    Parameters
    ----------
    x: torch.Tensor
        The ground-truth tensor. Shape is (batch, 1, dim1, dim2).
    mean: torch.Tensor
        The inferred mean of distribution. Shape is (batch, 1, dim1, dim2).
    std: torch.Tensor
        The inferred standard deviation of distribution. Shape is (batch, 1, dim1, dim2).

    Returns
    -------
    tmp: torch.Tensor
        Element-wise normal probability density of `x` given `mean` and `std`.
    """
    exponent = -((x - mean) ** 2) / (2.0 * std * std)
    normaliser = torch.sqrt((2.0 * np.pi) * std * std)
    return torch.exp(exponent) / normaliser
|
|
429
|
+
|
|
430
|
+
def likelihood(
    self, observations: torch.Tensor, signals: torch.Tensor
) -> torch.Tensor:
    """
    Evaluate the GMM likelihood of `observations` given the underlying `signals`.

    Both inputs are cast to float32. The Gaussian parameters for `signals` are
    obtained from `get_gaussian_parameters`. The per-component normal densities
    are summed, each weighted by its mixture weight, and `self.tolerance` is
    added to the total.

    Parameters
    ----------
    observations : Tensor
        Noisy observations. Shape is (batch, 1, dim1, dim2).
    signals : Tensor
        Underlying signals. Shape is (batch, 1, dim1, dim2).

    Returns
    -------
    value: torch.Tensor:
        Likelihood of observations given the signals and the GMM noise model,
        offset by `self.tolerance`.
    """
    obs = observations.float()
    params: list[torch.Tensor] = self.get_gaussian_parameters(signals.float())
    k = self.n_gaussian

    mixture = 0
    for g in range(k):
        # params is laid out as [means..., stds..., weights...].
        component = self.normal_density(obs, params[g], params[k + g])
        mixture += component * params[2 * k + g]
    return mixture + self.tolerance
|
|
468
|
+
|
|
469
|
+
def get_gaussian_parameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
    """
    Returns the noise model for given signals

    For every element of `signals`, evaluates the mean, standard deviation and
    mixture weight of each Gaussian component from the polynomial rows of
    `self.weight`.

    Parameters
    ----------
    signals : Tensor
        Underlying signals

    Returns
    -------
    noise_model: list of Tensor
        Flat list ``[mu_0..mu_{K-1}, sigma_0..sigma_{K-1}, alpha_0..alpha_{K-1}]``
        with ``K = self.weight.shape[0] // 3``. Each entry has the shape of
        `signals`. The alphas sum to 1 over the components.
    """
    kernels = self.weight.shape[0] // 3
    device = signals.device
    # Buffers are moved lazily to the device of the incoming signals.
    self.min_signal = self.min_signal.to(device)
    self.max_signal = self.max_signal.to(device)
    self.min_sigma = self.min_sigma.to(device)
    self.tolerance = self.tolerance.to(device)

    if kernels == 0:
        return []

    mu = []
    sigma = []
    alpha_logits = []
    for num in range(kernels):
        mu.append(self.polynomial_regressor(self.weight[num, :], signals))

        # The variance polynomial uses exponentiated (positive) coefficients.
        # It is floored at `min_sigma` (applied to the variance) before the
        # square root is taken.
        variance = self.polynomial_regressor(
            torch.exp(self.weight[kernels + num, :]), signals
        )
        variance = torch.clamp(variance, min=self.min_sigma)
        sigma.append(torch.sqrt(variance))

        alpha_logits.append(
            self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
        )

    # Normalise the mixture weights with a softmax over the components.
    # Mathematically this equals exp(l + tol) / sum(exp(l + tol)), because the
    # constant tolerance shift cancels. Unlike that form, it cannot overflow
    # to inf / inf = NaN when a logit is large.
    alpha = list(torch.softmax(torch.stack(alpha_logits, dim=0), dim=0).unbind(0))

    # Re-centre the means: subtract their alpha-weighted average and add the
    # signal. Because the alphas sum to 1, the mixture mean then equals
    # `signals` (up to rounding). The learned mean polynomials therefore only
    # shape the offsets between components, similar to a residual connection.
    weighted_mean = sum(a * m for a, m in zip(alpha, mu))
    mu = [m - weighted_mean + signals for m in mu]

    return mu + sigma + alpha
|
|
533
|
+
|
|
534
|
+
@staticmethod
|
|
535
|
+
def _fast_shuffle(series: torch.Tensor, num: int) -> torch.Tensor:
|
|
536
|
+
"""Shuffle the inputs randomly num times"""
|
|
537
|
+
length = series.shape[0]
|
|
538
|
+
for _ in range(num):
|
|
539
|
+
idx = torch.randperm(length)
|
|
540
|
+
series = series[idx, :]
|
|
541
|
+
return series
|
|
542
|
+
|
|
543
|
+
def get_signal_observation_pairs(
    self,
    signal: NDArray,
    observation: NDArray,
    lower_clip: float,
    upper_clip: float,
) -> torch.Tensor:
    """Returns the Signal-Observation pixel intensities as a two-column array

    Each observation ``i`` is paired with signal ``i // (n_observations //
    n_signals)``. Consecutive blocks of observations therefore share the
    same clean signal.

    Parameters
    ----------
    signal : numpy array
        Clean Signal Data
    observation: numpy array
        Noisy observation Data. Its first dimension must be a non-zero
        multiple of the first dimension of `signal`.
    lower_clip: float
        Lower percentile bound for clipping.
    upper_clip: float
        Upper percentile bound for clipping.

    Returns
    -------
    torch.Tensor
        Float32 tensor of shape ``(N, 2)`` with column 0 holding signal and
        column 1 observation intensities, randomly shuffled along the rows.
        Only pairs whose signal lies strictly between the two percentiles are
        kept. The percentile values themselves are excluded, even at the
        0 / 100 defaults.

    Raises
    ------
    ValueError
        If the number of observations is not a non-zero multiple of the
        number of signals.
    """
    n_observations = observation.shape[0]
    n_signals = signal.shape[0]
    # Previously a ZeroDivisionError (fewer observations than signals) or an
    # IndexError part-way through the loop (non-multiple counts).
    if n_signals == 0 or n_observations == 0 or n_observations % n_signals != 0:
        raise ValueError(
            f"The number of observations ({n_observations}) must be a non-zero "
            f"multiple of the number of signals ({n_signals})."
        )

    lb = np.percentile(signal, lower_clip)
    ub = np.percentile(signal, upper_clip)
    stepsize = observation[0].size
    obs_per_signal = n_observations // n_signals
    sig_obs_pairs = np.zeros((n_observations * stepsize, 2))

    for i in range(n_observations):
        j = i // obs_per_signal
        start, stop = stepsize * i, stepsize * (i + 1)
        sig_obs_pairs[start:stop, 0] = signal[j].ravel()
        sig_obs_pairs[start:stop, 1] = observation[i].ravel()

    keep = (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
    sig_obs_pairs = torch.from_numpy(sig_obs_pairs[keep].astype(np.float32))
    return self._fast_shuffle(sig_obs_pairs, 2)
|
|
585
|
+
|
|
586
|
+
def fit(
    self,
    signal: NDArray,
    observation: NDArray,
    learning_rate: float = 1e-1,
    batch_size: int = 250000,
    n_epochs: int = 2000,
    lower_clip: float = 0.0,
    upper_clip: float = 100.0,
) -> list[float]:
    """Training to learn the noise model from signal - observation pairs.

    Training runs on CUDA when available. Whether training finishes, aborts on
    NaN/Inf weights, or raises, the model is always returned to prediction
    mode on the CPU afterwards.

    Parameters
    ----------
    signal: numpy array
        Clean Signal Data
    observation: numpy array
        Noisy Observation Data
    learning_rate: float
        Learning rate. Default = 1e-1.
    batch_size: int
        Mini-batch size. Default = 250000.
    n_epochs: int
        Number of epochs (one mini-batch step each). Default = 2000.
    lower_clip : float
        Lower percentile for clipping. Default is 0.
    upper_clip : float
        Upper percentile for clipping. Default is 100.

    Returns
    -------
    list of float
        Negative log-likelihood loss of every step that was run.

    Raises
    ------
    ValueError
        If `batch_size` is not positive or no signal-observation pairs remain
        after clipping.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    self._set_model_mode(mode="train")
    train_losses: list[float] = []
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to_device(device)
        optimizer = torch.optim.Adam([self.weight], lr=learning_rate)

        sig_obs_pairs = self.get_signal_observation_pairs(
            signal, observation, lower_clip, upper_clip
        )
        if sig_obs_pairs.shape[0] == 0:
            # Without data every loss is NaN and the weights never move.
            raise ValueError(
                "No signal-observation pairs left after percentile clipping; "
                "cannot fit the noise model."
            )

        counter = 0
        for t in range(n_epochs):
            # Reshuffle and restart once the next batch would run past the end.
            if (counter + 1) * batch_size >= sig_obs_pairs.shape[0]:
                counter = 0
                sig_obs_pairs = self._fast_shuffle(sig_obs_pairs, 1)

            batch_vectors = sig_obs_pairs[
                counter * batch_size : (counter + 1) * batch_size, :
            ]
            observations = batch_vectors[:, 1].to(self.device)
            signals = batch_vectors[:, 0].to(self.device)

            p = self.likelihood(observations, signals)

            joint_loss = torch.mean(-torch.log(p))
            train_losses.append(joint_loss.item())

            if self.weight.isnan().any() or self.weight.isinf().any():
                print(
                    "NaN or Inf detected in the weights. Aborting training at epoch: ",
                    t,
                )
                break

            if t % 100 == 0:
                last_losses = train_losses[-100:]
                print(t, np.mean(last_losses))

            optimizer.zero_grad()
            joint_loss.backward()
            optimizer.step()
            counter += 1
    finally:
        # Never leave the model trainable or on the GPU, even on error/interrupt.
        self._set_model_mode(mode="prediction")
        self.to_device(torch.device("cpu"))

    print("===================\n")
    return train_losses
|
|
662
|
+
|
|
663
|
+
def sample_observation_from_signal(self, signal: NDArray) -> NDArray:
    """
    Sample an instance of observation based on an input signal using a
    learned Gaussian Mixture Model. For each pixel in the input signal,
    samples a corresponding noisy pixel.

    With more than one component, each pixel first picks a component by
    inverse-CDF sampling over the mixture weights. It then draws from that
    component's normal distribution. NumPy's global RNG is used.

    Parameters
    ----------
    signal: numpy array
        Clean 2D signal data.

    Returns
    -------
    observation: numpy array
        An instance of noisy observation data based on the input signal.
    """
    assert len(signal.shape) == 2, "Only 2D inputs are supported."

    signal_tensor = torch.from_numpy(signal).to(torch.float32)
    height, width = signal_tensor.shape

    with torch.no_grad():
        # Get gaussian parameters for each pixel and stack them into a single
        # (3 * n_gaussian, height, width) host array, whatever the device.
        gaussian_params = self.get_gaussian_parameters(signal_tensor)
        params = torch.stack(gaussian_params).cpu().numpy()
    means = params[: self.n_gaussian]
    stds = params[self.n_gaussian : self.n_gaussian * 2]
    alphas = params[self.n_gaussian * 2 :]

    if self.n_gaussian == 1:
        # Single gaussian case
        observation = np.random.normal(
            loc=means[0], scale=stds[0], size=(height, width)
        )
    else:
        # Multiple gaussians: sample component for each pixel
        uniform = np.random.rand(1, height, width)
        # Compute cumulative probabilities for component selection
        cumulative_alphas = np.cumsum(
            alphas, axis=0
        )  # Shape: (n_gaussian, height, width)
        # Float32 alphas may sum to slightly less than 1. A draw above the
        # last cumulative value would then match no component, and argmax
        # would silently fall back to component 0. Close the CDF explicitly.
        cumulative_alphas[-1] = 1.0
        selected_component = np.argmax(
            uniform < cumulative_alphas, axis=0, keepdims=True
        )

        # For every pixel, choose the corresponding gaussian
        # and get the learned mu and sigma
        selected_mus = np.take_along_axis(means, selected_component, axis=0)
        selected_stds = np.take_along_axis(stds, selected_component, axis=0)
        selected_mus = selected_mus.squeeze(0)
        selected_stds = selected_stds.squeeze(0)

        # Sample from the normal distribution with learned mu and sigma
        observation = np.random.normal(
            selected_mus, selected_stds, size=(height, width)
        )
    return observation
|
|
719
|
+
|
|
720
|
+
def save(self, path: str, name: str) -> None:
    """Save the trained parameters on the noise model.

    Writes an ``.npz`` archive with the keys ``trained_weight``,
    ``min_signal``, ``max_signal`` and ``min_sigma``. Tensors are detached and
    copied to host memory first, so saving works in train mode (where
    ``weight.requires_grad`` is True) and when the model lives on a GPU.

    Parameters
    ----------
    path : str
        Path to save the trained parameters. Created if it does not exist.
    name : str
        File name to save the trained parameters.
    """
    os.makedirs(path, exist_ok=True)

    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        # Tensor.numpy() refuses tensors that require grad or are not on CPU.
        return tensor.detach().cpu().numpy()

    np.savez(
        os.path.join(path, name),
        trained_weight=_to_numpy(self.weight),
        min_signal=_to_numpy(self.min_signal),
        max_signal=_to_numpy(self.max_signal),
        min_sigma=_to_numpy(self.min_sigma),
    )
    print("The trained parameters (" + name + ") is saved at location: " + path)
|