careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/models/unet.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
"""
|
|
2
|
+
UNet model.
|
|
3
|
+
|
|
4
|
+
A UNet encoder, decoder and complete model.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
from ..config.support import SupportedActivation
|
|
13
|
+
from .activation import get_activation
|
|
14
|
+
from .layers import Conv_Block, MaxBlurPool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UnetEncoder(nn.Module):
    """
    Contracting (encoder) path of a UNet.

    The encoder is a stack of `depth` convolution blocks, each followed by the
    same downsampling layer. The number of channels starts at
    `num_channels_init * groups` and doubles at every level.

    Parameters
    ----------
    conv_dim : int
        Dimensionality of the convolutions, 2 for 2D or 3 for 3D.
    in_channels : int, optional
        Number of channels of the input tensor, by default 1.
    depth : int, optional
        Number of convolution blocks (each followed by pooling), by default 3.
    num_channels_init : int, optional
        Base number of channels of the first block; it is multiplied by
        `groups`, by default 64.
    use_batch_norm : bool, optional
        Whether the convolution blocks use batch normalization, by default True.
    dropout : float, optional
        Dropout probability passed to the convolution blocks, by default 0.0.
    pool_kernel : int, optional
        Kernel size of the pooling layers, by default 2.
    n2v2 : bool, optional
        Whether to use the N2V2 variant, which pools with `MaxBlurPool` instead
        of max pooling, by default False.
    groups : int, optional
        Number of blocked connections from input channels to output
        channels, by default 1.
    """

    def __init__(
        self,
        conv_dim: int,
        in_channels: int = 1,
        depth: int = 3,
        num_channels_init: int = 64,
        use_batch_norm: bool = True,
        dropout: float = 0.0,
        pool_kernel: int = 2,
        n2v2: bool = False,
        groups: int = 1,
    ) -> None:
        """
        Build the encoder layers.

        Parameters
        ----------
        conv_dim : int
            Dimensionality of the convolutions, 2 for 2D or 3 for 3D.
        in_channels : int, optional
            Number of channels of the input tensor, by default 1.
        depth : int, optional
            Number of convolution blocks (each followed by pooling), by default 3.
        num_channels_init : int, optional
            Base number of channels of the first block; it is multiplied by
            `groups`, by default 64.
        use_batch_norm : bool, optional
            Whether the convolution blocks use batch normalization, by default
            True.
        dropout : float, optional
            Dropout probability passed to the convolution blocks, by default 0.0.
        pool_kernel : int, optional
            Kernel size of the pooling layers, by default 2.
        n2v2 : bool, optional
            Whether to use the N2V2 variant, which pools with `MaxBlurPool`
            instead of max pooling, by default False.
        groups : int, optional
            Number of blocked connections from input channels to output
            channels, by default 1.
        """
        super().__init__()

        # One pooling module is created, stored as `self.pooling`, and reused
        # after every convolution block.
        if n2v2:
            self.pooling = MaxBlurPool(
                dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel
            )
        else:
            max_pool_cls = getattr(nn, f"MaxPool{conv_dim}d")
            self.pooling = max_pool_cls(kernel_size=pool_kernel)

        layers: list[nn.Module] = []
        block_in = in_channels
        for level in range(depth):
            # channel count doubles at every level and is scaled by `groups`
            block_out = num_channels_init * groups * 2**level
            layers.append(
                Conv_Block(
                    conv_dim,
                    in_channels=block_in,
                    out_channels=block_out,
                    dropout_perc=dropout,
                    use_batch_norm=use_batch_norm,
                    groups=groups,
                )
            )
            layers.append(self.pooling)
            # the next block consumes the (pooled) output of this one
            block_in = block_out
        self.encoder_blocks = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Run the input through the encoder.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        list[torch.Tensor]
            The output of the last encoder layer, followed by the output of each
            convolution block (the skip connections) in the order the blocks
            are applied.
        """
        skips: list[torch.Tensor] = []
        out = x
        for layer in self.encoder_blocks:
            out = layer(out)
            if isinstance(layer, Conv_Block):
                skips.append(out)
        return [out] + skips
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class UnetDecoder(nn.Module):
|
|
132
|
+
"""
|
|
133
|
+
Unet decoder pathway.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
conv_dim : int
|
|
138
|
+
Number of dimension of the convolution layers, 2 for 2D or 3 for 3D.
|
|
139
|
+
depth : int, optional
|
|
140
|
+
Number of decoder blocks, by default 3.
|
|
141
|
+
num_channels_init : int, optional
|
|
142
|
+
Number of channels in the first encoder block, by default 64.
|
|
143
|
+
use_batch_norm : bool, optional
|
|
144
|
+
Whether to use batch normalization, by default True.
|
|
145
|
+
dropout : float, optional
|
|
146
|
+
Dropout probability, by default 0.0.
|
|
147
|
+
n2v2 : bool, optional
|
|
148
|
+
Whether to use N2V2 architecture, by default False.
|
|
149
|
+
groups : int, optional
|
|
150
|
+
Number of blocked connections from input channels to output
|
|
151
|
+
channels, by default 1.
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
def __init__(
|
|
155
|
+
self,
|
|
156
|
+
conv_dim: int,
|
|
157
|
+
depth: int = 3,
|
|
158
|
+
num_channels_init: int = 64,
|
|
159
|
+
use_batch_norm: bool = True,
|
|
160
|
+
dropout: float = 0.0,
|
|
161
|
+
n2v2: bool = False,
|
|
162
|
+
groups: int = 1,
|
|
163
|
+
) -> None:
|
|
164
|
+
"""
|
|
165
|
+
Constructor.
|
|
166
|
+
|
|
167
|
+
Parameters
|
|
168
|
+
----------
|
|
169
|
+
conv_dim : int
|
|
170
|
+
Number of dimension of the convolution layers, 2 for 2D or 3 for 3D.
|
|
171
|
+
depth : int, optional
|
|
172
|
+
Number of decoder blocks, by default 3.
|
|
173
|
+
num_channels_init : int, optional
|
|
174
|
+
Number of channels in the first encoder block, by default 64.
|
|
175
|
+
use_batch_norm : bool, optional
|
|
176
|
+
Whether to use batch normalization, by default True.
|
|
177
|
+
dropout : float, optional
|
|
178
|
+
Dropout probability, by default 0.0.
|
|
179
|
+
n2v2 : bool, optional
|
|
180
|
+
Whether to use N2V2 architecture, by default False.
|
|
181
|
+
groups : int, optional
|
|
182
|
+
Number of blocked connections from input channels to output
|
|
183
|
+
channels, by default 1.
|
|
184
|
+
"""
|
|
185
|
+
super().__init__()
|
|
186
|
+
|
|
187
|
+
upsampling = nn.Upsample(
|
|
188
|
+
scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
|
|
189
|
+
)
|
|
190
|
+
in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
|
|
191
|
+
|
|
192
|
+
self.n2v2 = n2v2
|
|
193
|
+
self.groups = groups
|
|
194
|
+
|
|
195
|
+
self.bottleneck = Conv_Block(
|
|
196
|
+
conv_dim,
|
|
197
|
+
in_channels=in_channels,
|
|
198
|
+
out_channels=out_channels,
|
|
199
|
+
intermediate_channel_multiplier=2,
|
|
200
|
+
use_batch_norm=use_batch_norm,
|
|
201
|
+
dropout_perc=dropout,
|
|
202
|
+
groups=self.groups,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
decoder_blocks: list[nn.Module] = []
|
|
206
|
+
for n in range(depth):
|
|
207
|
+
decoder_blocks.append(upsampling)
|
|
208
|
+
|
|
209
|
+
in_channels = (num_channels_init * 2 ** (depth - n - 1)) * groups
|
|
210
|
+
# final decoder block has the same number in and out features
|
|
211
|
+
out_channels = in_channels // 2 if n != depth - 1 else in_channels
|
|
212
|
+
if not (n2v2 and (n == depth - 1)):
|
|
213
|
+
in_channels = in_channels * 2 # accounting for skip connection concat
|
|
214
|
+
|
|
215
|
+
decoder_blocks.append(
|
|
216
|
+
Conv_Block(
|
|
217
|
+
conv_dim,
|
|
218
|
+
in_channels=in_channels,
|
|
219
|
+
out_channels=out_channels,
|
|
220
|
+
# TODO: Tensorflow n2v implementation has intermediate channel
|
|
221
|
+
# multiplication for skip_skipone=True but not skip_skipone=False
|
|
222
|
+
# this needs to be benchmarked.
|
|
223
|
+
# final decoder block doesn't multiply the intermediate features
|
|
224
|
+
intermediate_channel_multiplier=2 if n != depth - 1 else 1,
|
|
225
|
+
dropout_perc=dropout,
|
|
226
|
+
activation="ReLU",
|
|
227
|
+
use_batch_norm=use_batch_norm,
|
|
228
|
+
groups=groups,
|
|
229
|
+
)
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
self.decoder_blocks = nn.ModuleList(decoder_blocks)
|
|
233
|
+
|
|
234
|
+
def forward(self, *features: torch.Tensor) -> torch.Tensor:
|
|
235
|
+
"""
|
|
236
|
+
Forward pass.
|
|
237
|
+
|
|
238
|
+
Parameters
|
|
239
|
+
----------
|
|
240
|
+
*features : list[torch.Tensor]
|
|
241
|
+
List containing the output of each encoder block(skip connections) and final
|
|
242
|
+
output of the encoder.
|
|
243
|
+
|
|
244
|
+
Returns
|
|
245
|
+
-------
|
|
246
|
+
torch.Tensor
|
|
247
|
+
Output of the decoder.
|
|
248
|
+
"""
|
|
249
|
+
x: torch.Tensor = features[0]
|
|
250
|
+
skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
|
|
251
|
+
depth = len(skip_connections)
|
|
252
|
+
|
|
253
|
+
x = self.bottleneck(x)
|
|
254
|
+
|
|
255
|
+
for i, module in enumerate(self.decoder_blocks):
|
|
256
|
+
x = module(x)
|
|
257
|
+
if isinstance(module, nn.Upsample):
|
|
258
|
+
# divide index by 2 because of upsampling layers
|
|
259
|
+
skip_connection: torch.Tensor = skip_connections[i // 2]
|
|
260
|
+
# top level skip connection not added for n2v2
|
|
261
|
+
if (not self.n2v2) or (self.n2v2 and (i // 2 < depth - 1)):
|
|
262
|
+
x = self._interleave(x, skip_connection, self.groups)
|
|
263
|
+
return x
|
|
264
|
+
|
|
265
|
+
@staticmethod
|
|
266
|
+
def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
|
|
267
|
+
"""Interleave two tensors.
|
|
268
|
+
|
|
269
|
+
Splits the tensors `A` and `B` into equally sized groups along the channel
|
|
270
|
+
axis (axis=1); then concatenates the groups in alternating order along the
|
|
271
|
+
channel axis, starting with the first group from tensor A.
|
|
272
|
+
|
|
273
|
+
Parameters
|
|
274
|
+
----------
|
|
275
|
+
A : torch.Tensor
|
|
276
|
+
First tensor.
|
|
277
|
+
B : torch.Tensor
|
|
278
|
+
Second tensor.
|
|
279
|
+
groups : int
|
|
280
|
+
The number of groups.
|
|
281
|
+
|
|
282
|
+
Returns
|
|
283
|
+
-------
|
|
284
|
+
torch.Tensor
|
|
285
|
+
Interleaved tensor.
|
|
286
|
+
|
|
287
|
+
Raises
|
|
288
|
+
------
|
|
289
|
+
ValueError:
|
|
290
|
+
If either of `A` or `B`'s channel axis is not divisible by `groups`.
|
|
291
|
+
"""
|
|
292
|
+
if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
|
|
293
|
+
raise ValueError(f"Number of channels not divisible by {groups} groups.")
|
|
294
|
+
|
|
295
|
+
m = A.shape[1] // groups
|
|
296
|
+
n = B.shape[1] // groups
|
|
297
|
+
|
|
298
|
+
A_groups: list[torch.Tensor] = [
|
|
299
|
+
A[:, i * m : (i + 1) * m] for i in range(groups)
|
|
300
|
+
]
|
|
301
|
+
B_groups: list[torch.Tensor] = [
|
|
302
|
+
B[:, i * n : (i + 1) * n] for i in range(groups)
|
|
303
|
+
]
|
|
304
|
+
|
|
305
|
+
interleaved = torch.cat(
|
|
306
|
+
[
|
|
307
|
+
tensor_list[i]
|
|
308
|
+
for i in range(groups)
|
|
309
|
+
for tensor_list in [A_groups, B_groups]
|
|
310
|
+
],
|
|
311
|
+
dim=1,
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
return interleaved
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class UNet(nn.Module):
|
|
318
|
+
"""
|
|
319
|
+
UNet model.
|
|
320
|
+
|
|
321
|
+
Adapted for PyTorch from:
|
|
322
|
+
https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.
|
|
323
|
+
|
|
324
|
+
Parameters
|
|
325
|
+
----------
|
|
326
|
+
conv_dims : int
|
|
327
|
+
Number of dimensions of the convolution layers (2 or 3).
|
|
328
|
+
num_classes : int, optional
|
|
329
|
+
Number of classes to predict, by default 1.
|
|
330
|
+
in_channels : int, optional
|
|
331
|
+
Number of input channels, by default 1.
|
|
332
|
+
depth : int, optional
|
|
333
|
+
Number of downsamplings, by default 3.
|
|
334
|
+
num_channels_init : int, optional
|
|
335
|
+
Number of filters in the first convolution layer, by default 64.
|
|
336
|
+
use_batch_norm : bool, optional
|
|
337
|
+
Whether to use batch normalization, by default True.
|
|
338
|
+
dropout : float, optional
|
|
339
|
+
Dropout probability, by default 0.0.
|
|
340
|
+
pool_kernel : int, optional
|
|
341
|
+
Kernel size of the pooling layers, by default 2.
|
|
342
|
+
final_activation : Optional[Callable], optional
|
|
343
|
+
Activation function to use for the last layer, by default None.
|
|
344
|
+
n2v2 : bool, optional
|
|
345
|
+
Whether to use N2V2 architecture, by default False.
|
|
346
|
+
independent_channels : bool
|
|
347
|
+
Whether to train the channels independently, by default True.
|
|
348
|
+
**kwargs : Any
|
|
349
|
+
Additional keyword arguments, unused.
|
|
350
|
+
"""
|
|
351
|
+
|
|
352
|
+
def __init__(
|
|
353
|
+
self,
|
|
354
|
+
conv_dims: int,
|
|
355
|
+
num_classes: int = 1,
|
|
356
|
+
in_channels: int = 1,
|
|
357
|
+
depth: int = 3,
|
|
358
|
+
num_channels_init: int = 64,
|
|
359
|
+
use_batch_norm: bool = True,
|
|
360
|
+
dropout: float = 0.0,
|
|
361
|
+
pool_kernel: int = 2,
|
|
362
|
+
final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
|
|
363
|
+
n2v2: bool = False,
|
|
364
|
+
independent_channels: bool = True,
|
|
365
|
+
**kwargs: Any,
|
|
366
|
+
) -> None:
|
|
367
|
+
"""
|
|
368
|
+
Constructor.
|
|
369
|
+
|
|
370
|
+
Parameters
|
|
371
|
+
----------
|
|
372
|
+
conv_dims : int
|
|
373
|
+
Number of dimensions of the convolution layers (2 or 3).
|
|
374
|
+
num_classes : int, optional
|
|
375
|
+
Number of classes to predict, by default 1.
|
|
376
|
+
in_channels : int, optional
|
|
377
|
+
Number of input channels, by default 1.
|
|
378
|
+
depth : int, optional
|
|
379
|
+
Number of downsamplings, by default 3.
|
|
380
|
+
num_channels_init : int, optional
|
|
381
|
+
Number of filters in the first convolution layer, by default 64.
|
|
382
|
+
use_batch_norm : bool, optional
|
|
383
|
+
Whether to use batch normalization, by default True.
|
|
384
|
+
dropout : float, optional
|
|
385
|
+
Dropout probability, by default 0.0.
|
|
386
|
+
pool_kernel : int, optional
|
|
387
|
+
Kernel size of the pooling layers, by default 2.
|
|
388
|
+
final_activation : Optional[Callable], optional
|
|
389
|
+
Activation function to use for the last layer, by default None.
|
|
390
|
+
n2v2 : bool, optional
|
|
391
|
+
Whether to use N2V2 architecture, by default False.
|
|
392
|
+
independent_channels : bool
|
|
393
|
+
Whether to train parallel independent networks for each channel, by
|
|
394
|
+
default True.
|
|
395
|
+
**kwargs : Any
|
|
396
|
+
Additional keyword arguments, unused.
|
|
397
|
+
"""
|
|
398
|
+
super().__init__()
|
|
399
|
+
|
|
400
|
+
groups = in_channels if independent_channels else 1
|
|
401
|
+
|
|
402
|
+
self.encoder = UnetEncoder(
|
|
403
|
+
conv_dims,
|
|
404
|
+
in_channels=in_channels,
|
|
405
|
+
depth=depth,
|
|
406
|
+
num_channels_init=num_channels_init,
|
|
407
|
+
use_batch_norm=use_batch_norm,
|
|
408
|
+
dropout=dropout,
|
|
409
|
+
pool_kernel=pool_kernel,
|
|
410
|
+
n2v2=n2v2,
|
|
411
|
+
groups=groups,
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
self.decoder = UnetDecoder(
|
|
415
|
+
conv_dims,
|
|
416
|
+
depth=depth,
|
|
417
|
+
num_channels_init=num_channels_init,
|
|
418
|
+
use_batch_norm=use_batch_norm,
|
|
419
|
+
dropout=dropout,
|
|
420
|
+
n2v2=n2v2,
|
|
421
|
+
groups=groups,
|
|
422
|
+
)
|
|
423
|
+
self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
|
|
424
|
+
in_channels=num_channels_init * groups,
|
|
425
|
+
out_channels=num_classes,
|
|
426
|
+
kernel_size=1,
|
|
427
|
+
groups=groups,
|
|
428
|
+
)
|
|
429
|
+
self.final_activation = get_activation(final_activation)
|
|
430
|
+
|
|
431
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Run the network on a batch of inputs.

    The input is passed through the encoder. Whatever the encoder returns is
    unpacked as positional arguments into the decoder. The decoded features
    are then projected by ``final_conv`` and passed through
    ``final_activation``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Output of the model.
    """
    features = self.encoder(x)
    # The encoder yields several feature maps; the decoder takes them as
    # separate positional arguments.
    decoded = self.decoder(*features)
    projected = self.final_conv(decoded)
    return self.final_activation(projected)
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""Placeholder code snippets for noise model training integration.
|
|
2
|
+
|
|
3
|
+
This module contains template/placeholder code that demonstrates how noise model
|
|
4
|
+
training could be integrated into CAREamist. These are reference implementations
|
|
5
|
+
and should not be imported or used directly.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Union
|
|
11
|
+
|
|
12
|
+
from numpy.typing import NDArray
|
|
13
|
+
from pytorch_lightning.callbacks import Callback
|
|
14
|
+
|
|
15
|
+
from careamics.config.configuration import Configuration
|
|
16
|
+
from careamics.models.lvae.noise_models import (
|
|
17
|
+
GaussianMixtureNoiseModel,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Module-scoped logger, registered under this module's dotted import name.
logger = logging.getLogger(name=__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# In src/careamics/careamist.py (newly added section only)
|
|
24
|
+
def __init__(
    self,
    source: Union[Path, str, Configuration],
    work_dir: Union[Path, str] | None = None,
    callbacks: list[Callback] | None = None,
    enable_progress_bar: bool = True,
) -> None:
    """Placeholder __init__ method showing noise model initialization.

    Only the noise-model-related tail of the initializer is shown. The
    ``untrained_noise_models`` attribute always starts as ``None``. It is
    populated by ``_initialize_noise_models_for_training`` only when the
    algorithm configuration requests training noise models from data.

    Parameters
    ----------
    self : object
        CAREamist instance.
    source : Union[Path, str, Configuration]
        Configuration source.
    work_dir : Union[Path, str] | None, optional
        Working directory, by default None.
    callbacks : list[Callback] | None, optional
        List of callbacks, by default None.
    enable_progress_bar : bool, optional
        Whether to show progress bar, by default True.
    """
    # ... existing initialization code ...

    # Initialize untrained noise models if needed
    self.untrained_noise_models = None
    # Test and read the same attribute, defaulting to False. Algorithm configs
    # that do not define the flag then skip noise model setup instead of
    # raising AttributeError.
    if getattr(self.cfg.algorithm_config, "train_noise_model_from_data", False):
        self._initialize_noise_models_for_training()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# In src/careamics/careamist.py
|
|
58
|
+
def train_noise_model(
    self,
    clean_data: Union[Path, str, NDArray],
    noisy_data: Union[Path, str, NDArray],
    learning_rate: float = 1e-1,
    batch_size: int = 250000,
    n_epochs: int = 2000,
    lower_clip: float = 0.0,
    upper_clip: float = 100.0,
    save_noise_models: bool = True,
) -> None:
    """Train noise models from clean/noisy data pairs, one per output channel.

    Channel ``c`` of the data (``data[:, c]``) is used to fit
    ``self.untrained_noise_models[c]`` for every ``c`` in
    ``range(cfg.algorithm_config.model.output_channels)``.

    Parameters
    ----------
    self : object
        CAREamist instance.
    clean_data : Union[Path, str, NDArray]
        Clean (signal) data for training noise models.
    noisy_data : Union[Path, str, NDArray]
        Noisy (observation) data for training noise models.
    learning_rate : float, default=1e-1
        Learning rate for noise model training.
    batch_size : int, default=250000
        Batch size for noise model training.
    n_epochs : int, default=2000
        Number of epochs for noise model training.
    lower_clip : float, default=0.0
        Lower percentile for clipping training data.
    upper_clip : float, default=100.0
        Upper percentile for clipping training data.
    save_noise_models : bool, default=True
        Whether to save trained noise models to disk under
        ``work_dir / "noise_models"``.

    Raises
    ------
    ValueError
        If the configuration has no noise model, or no untrained noise models
        were initialized on this instance.
    ValueError
        If the clean and noisy data shapes differ, the data have no channel
        axis (axis 1) holding at least ``output_channels`` channels, or there
        are fewer untrained noise models than output channels.
    """
    # Both the config entry and the per-channel untrained models are needed;
    # the latter stays None when __init__ skipped noise model setup.
    untrained_noise_models = getattr(self, "untrained_noise_models", None)
    if (
        self.cfg.algorithm_config.noise_model is None
        or untrained_noise_models is None
    ):
        raise ValueError(
            "No untrained noise models found. Set `train_noise_model=True` "
            "in configuration."
        )

    # Load data if paths provided (currently NM expects only numpy)
    if isinstance(clean_data, (str, Path)):
        clean_data = self._load_data(clean_data)
    if isinstance(noisy_data, (str, Path)):
        noisy_data = self._load_data(noisy_data)

    # Type narrowing for mypy
    assert not isinstance(clean_data, (str, Path))
    assert not isinstance(noisy_data, (str, Path))

    # Validate data shapes
    if clean_data.shape != noisy_data.shape:
        raise ValueError(
            f"Clean and noisy data shapes must match. "
            f"Got clean: {clean_data.shape}, noisy: {noisy_data.shape}"
        )

    # parameter controlling the number of channels to split for MS, for HDN it's 1
    output_channels = self.cfg.algorithm_config.model.output_channels

    # The per-channel slicing below indexes axis 1, so it must exist and hold
    # every channel we intend to train on. Otherwise the caller would get a
    # bare IndexError.
    if clean_data.ndim < 2 or clean_data.shape[1] < output_channels:
        raise ValueError(
            f"Expected data with a channel axis (axis 1) holding at least "
            f"{output_channels} channels, got shape {clean_data.shape}."
        )
    if len(untrained_noise_models) < output_channels:
        raise ValueError(
            f"Expected {output_channels} untrained noise models, "
            f"got {len(untrained_noise_models)}."
        )

    save_path = self.work_dir / "noise_models"
    if save_noise_models:
        # Make sure the target directory exists before any model is written.
        save_path.mkdir(parents=True, exist_ok=True)

    # Train noise model for each channel
    trained_noise_models = []
    for channel_idx in range(output_channels):
        logger.info(
            "Training noise model for channel %d/%d",
            channel_idx + 1,
            output_channels,
        )

        # Extract single channel data (channel is axis 1)
        clean_channel = clean_data[:, channel_idx]
        noisy_channel = noisy_data[:, channel_idx]

        # Train noise model for this channel
        noise_model = untrained_noise_models[channel_idx]
        noise_model.fit(
            signal=clean_channel,
            observation=noisy_channel,
            learning_rate=learning_rate,
            batch_size=batch_size,
            n_epochs=n_epochs,
            lower_clip=lower_clip,
            upper_clip=upper_clip,
        )

        trained_noise_models.append(noise_model)

        # Save individual noise model if requested
        if save_noise_models:
            noise_model.save(str(save_path), f"noise_model_ch{channel_idx}.npz")
            logger.info(
                "Saved noise model for channel %d to %s", channel_idx, save_path
            )

    # Update the algorithm configuration with trained noise models
    self._update_config_with_trained_noise_models(trained_noise_models)

    logger.info("Noise model training completed successfully")
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _update_config_with_trained_noise_models(
    self, trained_models: list[GaussianMixtureNoiseModel]
) -> None:
    """Propagate freshly trained noise models into the algorithm config.

    Placeholder: the body performs no work yet, so calling it has no effect
    and it returns ``None``.

    Parameters
    ----------
    self : object
        CAREamist instance.
    trained_models : list[GaussianMixtureNoiseModel]
        Trained noise models, one entry per channel.
    """
    # TODO: wire the trained models into the model.
    #
    # The model is currently built in CAREamist.__init__. The
    # multichannel_noise_model_factory used inside VAEModule only accepts
    # paths to saved noise models.
    #
    # The intended design is to call that factory here, after model
    # initialization, and update the MultiChannelNoiseModel parameters
    # directly.
    return None
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _load_data(self, data_path: Union[Path, str]) -> NDArray:
    """Read an array from ``data_path`` (not implemented yet).

    Parameters
    ----------
    self : object
        CAREamist instance.
    data_path : Union[Path, str]
        Location of the file to read.

    Returns
    -------
    NDArray
        The loaded array, once implemented.

    Raises
    ------
    NotImplementedError
        Always; reading data from disk is still a placeholder.
    """
    message = "Data loading not yet implemented"
    raise NotImplementedError(message)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Package to house various prediction utilies."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"convert_outputs",
|
|
5
|
+
"convert_outputs_microsplit",
|
|
6
|
+
"convert_outputs_pn2v",
|
|
7
|
+
"stitch_prediction",
|
|
8
|
+
"stitch_prediction_single",
|
|
9
|
+
"stitch_prediction_vae",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
from .prediction_outputs import (
|
|
13
|
+
convert_outputs,
|
|
14
|
+
convert_outputs_microsplit,
|
|
15
|
+
convert_outputs_pn2v,
|
|
16
|
+
)
|
|
17
|
+
from .stitch_prediction import (
|
|
18
|
+
stitch_prediction,
|
|
19
|
+
stitch_prediction_single,
|
|
20
|
+
stitch_prediction_vae,
|
|
21
|
+
)
|