careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""N2V manipulation transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
9
|
+
from careamics.transforms.transform import Transform
|
|
10
|
+
|
|
11
|
+
from .pixel_manipulation import median_manipulate, uniform_manipulate
|
|
12
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class N2VManipulate(Transform):
    """
    Default augmentation for the N2V model.

    This transform expects C(Z)YX dimensions. Each channel is masked
    independently, using either the uniform or the median replacement strategy.

    Parameters
    ----------
    roi_size : int, optional
        Size of the replacement area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels to mask, by default 0.2.
    strategy : Literal[ "uniform", "median" ], optional
        Replacement strategy, uniform or median, by default uniform.
    remove_center : bool, optional
        Whether to remove central pixel from patch, by default True. Only
        forwarded to the uniform strategy.
    struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
        StructN2V mask axis, by default "none".
    struct_mask_span : int, optional
        StructN2V mask span, by default 5.
    seed : int or None, optional
        Random seed, by default None.

    Attributes
    ----------
    masked_pixel_percentage : float
        Percentage of pixels to mask.
    roi_size : int
        Size of the replacement area.
    strategy : Literal[ "uniform", "median" ]
        Replacement strategy, uniform or median.
    remove_center : bool
        Whether to remove central pixel from patch.
    struct_mask : Optional[StructMaskParameters]
        StructN2V mask parameters, None when `struct_mask_axis` is "none".
    rng : Generator
        Random number generator.
    """

    def __init__(
        self,
        roi_size: int = 11,
        masked_pixel_percentage: float = 0.2,
        strategy: Literal[
            "uniform", "median"
        ] = SupportedPixelManipulation.UNIFORM.value,
        remove_center: bool = True,
        struct_mask_axis: Literal["horizontal", "vertical", "none"] = "none",
        struct_mask_span: int = 5,
        seed: int | None = None,
    ):
        """Constructor.

        Parameters
        ----------
        roi_size : int, optional
            Size of the replacement area, by default 11.
        masked_pixel_percentage : float, optional
            Percentage of pixels to mask, by default 0.2.
        strategy : Literal[ "uniform", "median" ], optional
            Replacement strategy, uniform or median, by default uniform.
        remove_center : bool, optional
            Whether to remove central pixel from patch, by default True.
        struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
            StructN2V mask axis, by default "none".
        struct_mask_span : int, optional
            StructN2V mask span, by default 5.
        seed : int or None, optional
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `struct_mask_axis` is not "horizontal", "vertical" or "none".
        """
        self.masked_pixel_percentage = masked_pixel_percentage
        self.roi_size = roi_size
        self.strategy = strategy
        # Forwarded to `uniform_manipulate` only; the median strategy call in
        # `__call__` does not take this argument.
        self.remove_center = remove_center

        # Map the StructN2V axis name to the axis index expected by
        # `StructMaskParameters`. Unknown names are rejected explicitly instead
        # of silently falling back to the vertical axis.
        if struct_mask_axis == SupportedStructAxis.NONE:
            self.struct_mask: StructMaskParameters | None = None
        elif struct_mask_axis == SupportedStructAxis.HORIZONTAL:
            self.struct_mask = StructMaskParameters(axis=0, span=struct_mask_span)
        elif struct_mask_axis == SupportedStructAxis.VERTICAL:
            self.struct_mask = StructMaskParameters(axis=1, span=struct_mask_span)
        else:
            raise ValueError(
                f"Unknown struct mask axis ({struct_mask_axis}), expected one of "
                f"'horizontal', 'vertical' or 'none'."
            )

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self, patch: NDArray, *args: Any, **kwargs: Any
    ) -> tuple[NDArray, NDArray, NDArray]:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        *args : Any
            Additional arguments, unused.
        **kwargs : Any
            Additional keyword arguments, unused.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, np.ndarray]
            Masked patch, original patch (the input array itself), and mask.
            The masked patch and mask are allocated with the dtype of `patch`.

        Raises
        ------
        ValueError
            If `strategy` is neither "uniform" nor "median".
        """
        masked = np.zeros_like(patch)
        mask = np.zeros_like(patch)
        if self.strategy == SupportedPixelManipulation.UNIFORM:
            # Iterate over the channels to apply manipulation separately
            for c in range(patch.shape[0]):
                masked[c, ...], mask[c, ...] = uniform_manipulate(
                    patch=patch[c, ...],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    remove_center=self.remove_center,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )
        elif self.strategy == SupportedPixelManipulation.MEDIAN:
            # Iterate over the channels to apply manipulation separately
            for c in range(patch.shape[0]):
                masked[c, ...], mask[c, ...] = median_manipulate(
                    patch=patch[c, ...],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )
        else:
            raise ValueError(f"Unknown masking strategy ({self.strategy}).")

        # TODO: Output does not match other transforms, how to resolve?
        # - Don't include in Compose and apply after if algorithm is N2V?
        # - or just don't return patch? but then mask is in the target position
        # TODO why return patch?
        return masked, patch, mask
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""N2V manipulation transform for PyTorch."""
|
|
2
|
+
|
|
3
|
+
import platform
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
9
|
+
from careamics.config.transformations import N2VManipulateConfig
|
|
10
|
+
|
|
11
|
+
from .pixel_manipulation_torch import (
|
|
12
|
+
median_manipulate_torch,
|
|
13
|
+
uniform_manipulate_torch,
|
|
14
|
+
)
|
|
15
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class N2VManipulateTorch:
    """
    Default augmentation for the N2V model.

    This transform expects C(Z)YX dimensions for each sample, i.e. batches of
    shape BC(Z)YX. Each channel is masked independently.

    Parameters
    ----------
    n2v_manipulate_config : N2VManipulateConfig
        N2V manipulation configuration.
    seed : int or None, optional
        Random seed, by default None.
    device : str or None, optional
        The device on which operations take place, e.g. "cuda", "cpu" or "mps".
        By default None, in which case it is selected automatically.

    Attributes
    ----------
    masked_pixel_percentage : float
        Percentage of pixels to mask.
    roi_size : int
        Size of the replacement area.
    strategy : Literal[ "uniform", "median" ]
        Replacement strategy, uniform or median.
    remove_center : bool
        Whether to remove central pixel from patch.
    struct_mask : Optional[StructMaskParameters]
        StructN2V mask parameters, None when the struct mask axis is "none".
    rng : torch.Generator
        Random number generator.
    """

    def __init__(
        self,
        n2v_manipulate_config: N2VManipulateConfig,
        seed: int | None = None,
        device: str | None = None,
    ):
        """Constructor.

        Parameters
        ----------
        n2v_manipulate_config : N2VManipulateConfig
            N2V manipulation configuration.
        seed : int or None, optional
            Random seed, by default None.
        device : str or None, optional
            The device on which operations take place, e.g. "cuda", "cpu" or
            "mps". If None, "cuda" is used when available, then "mps" on ARM
            processors, otherwise "cpu".

        Raises
        ------
        ValueError
            If the configured struct mask axis is not "horizontal", "vertical"
            or "none".
        """
        self.masked_pixel_percentage = n2v_manipulate_config.masked_pixel_percentage
        self.roi_size = n2v_manipulate_config.roi_size
        self.strategy = n2v_manipulate_config.strategy
        # Forwarded to `uniform_manipulate_torch` only; the median strategy call
        # in `__call__` does not take this argument.
        self.remove_center = n2v_manipulate_config.remove_center

        # Map the StructN2V axis name to the axis index expected by
        # `StructMaskParameters`.
        struct_axis = n2v_manipulate_config.struct_mask_axis
        struct_span = n2v_manipulate_config.struct_mask_span
        if struct_axis == SupportedStructAxis.NONE:
            self.struct_mask: StructMaskParameters | None = None
        elif struct_axis == SupportedStructAxis.HORIZONTAL:
            self.struct_mask = StructMaskParameters(axis=0, span=struct_span)
        elif struct_axis == SupportedStructAxis.VERTICAL:
            self.struct_mask = StructMaskParameters(axis=1, span=struct_span)
        else:
            raise ValueError(
                f"Unknown struct mask axis ({struct_axis}), expected one of "
                f"'horizontal', 'vertical' or 'none'."
            )

        # PyTorch random generator, created on the selected device and only
        # seeded when a seed is given.
        if device is None:
            device = self._default_device()
        self.rng = torch.Generator(device=device)
        if seed is not None:
            self.rng.manual_seed(seed)

    @staticmethod
    def _default_device() -> str:
        """Select "cuda" if available, else "mps" on ARM processors, else "cpu".

        Returns
        -------
        str
            Name of the selected device.
        """
        # TODO refactor into careamics.utils.torch_utils.get_device
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available() and platform.processor() in (
            "arm",
            "arm64",
        ):
            return "mps"
        return "cpu"

    def __call__(
        self, batch: torch.Tensor, *args: Any, **kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the transform to the image.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of image patches, 2D or 3D, shape BC(Z)YX.
        *args : Any
            Additional arguments, unused.
        **kwargs : Any
            Additional keyword arguments, unused.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Masked batch, original batch (the input tensor itself), and mask
            (dtype `torch.uint8`).

        Raises
        ------
        ValueError
            If `strategy` is neither "uniform" nor "median".
        """
        masked = torch.zeros_like(batch)
        mask = torch.zeros_like(batch, dtype=torch.uint8)

        if self.strategy == SupportedPixelManipulation.UNIFORM:
            # Iterate over the channels to apply manipulation separately
            for c in range(batch.shape[1]):
                masked[:, c, ...], mask[:, c, ...] = uniform_manipulate_torch(
                    patch=batch[:, c, ...],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    remove_center=self.remove_center,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )
        elif self.strategy == SupportedPixelManipulation.MEDIAN:
            # Iterate over the channels to apply manipulation separately
            for c in range(batch.shape[1]):
                masked[:, c, ...], mask[:, c, ...] = median_manipulate_torch(
                    batch=batch[:, c, ...],
                    mask_pixel_percentage=self.masked_pixel_percentage,
                    subpatch_size=self.roi_size,
                    struct_params=self.struct_mask,
                    rng=self.rng,
                )
        else:
            raise ValueError(f"Unknown masking strategy ({self.strategy}).")

        return masked, batch, mask
|
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
"""Normalization and denormalization transforms for image patches."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
|
|
8
|
+
from careamics.transforms.transform import Transform
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _reshape_stats(stats: list[float], ndim: int) -> NDArray:
    """Append singleton axes to per-channel stats so they broadcast over an image.

    The stats are kept along the first axis and ``ndim - 1`` trailing axes of
    size one are appended, so that ``C`` values become an array of shape
    ``(C, 1, ..., 1)`` that can be combined element-wise with a C(Z)YX image.

    Parameters
    ----------
    stats : list of float
        List of stats, mean or standard deviation.
    ndim : int
        Number of dimensions of the image, including the C channel.

    Returns
    -------
    NDArray
        The stats as an array with ``ndim - 1`` trailing singleton axes.
    """
    stats_array = np.array(stats)
    # A non-positive repeat count yields an empty tuple, i.e. no extra axes.
    singleton_axes = (1,) * (ndim - 1)
    return stats_array.reshape(stats_array.shape + singleton_axes)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _reshape_stats_torch(stats: list[float], ndim: int) -> Tensor:
|
|
33
|
+
"""Torch equivalent of `_reshape_stats` for broadcasting over image dims.
|
|
34
|
+
|
|
35
|
+
Parameters
|
|
36
|
+
----------
|
|
37
|
+
stats : list of float
|
|
38
|
+
List of stats, mean or standard deviation.
|
|
39
|
+
ndim : int
|
|
40
|
+
Number of dimensions of the tensor, including the C channel.
|
|
41
|
+
|
|
42
|
+
Returns
|
|
43
|
+
-------
|
|
44
|
+
Tensor
|
|
45
|
+
Reshaped stats tensor.
|
|
46
|
+
"""
|
|
47
|
+
t = torch.tensor(stats)
|
|
48
|
+
# Add singleton dimensions to match input tensor ndim for broadcasting
|
|
49
|
+
return t[(..., *[None] * (ndim - 1))]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Normalize(Transform):
|
|
53
|
+
"""
|
|
54
|
+
Normalize an image or image patch.
|
|
55
|
+
|
|
56
|
+
Normalization is a zero mean and unit variance. This transform expects C(Z)YX
|
|
57
|
+
dimensions.
|
|
58
|
+
|
|
59
|
+
Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
|
|
60
|
+
division by zero and that it returns a float32 image.
|
|
61
|
+
|
|
62
|
+
Parameters
|
|
63
|
+
----------
|
|
64
|
+
image_means : list of float
|
|
65
|
+
Mean value per channel.
|
|
66
|
+
image_stds : list of float
|
|
67
|
+
Standard deviation value per channel.
|
|
68
|
+
target_means : list of float, optional
|
|
69
|
+
Target mean value per channel, by default None.
|
|
70
|
+
target_stds : list of float, optional
|
|
71
|
+
Target standard deviation value per channel, by default None.
|
|
72
|
+
|
|
73
|
+
Attributes
|
|
74
|
+
----------
|
|
75
|
+
image_means : list of float
|
|
76
|
+
Mean value per channel.
|
|
77
|
+
image_stds : list of float
|
|
78
|
+
Standard deviation value per channel.
|
|
79
|
+
target_means :list of float, optional
|
|
80
|
+
Target mean value per channel, by default None.
|
|
81
|
+
target_stds : list of float, optional
|
|
82
|
+
Target standard deviation value per channel, by default None.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
def __init__(
|
|
86
|
+
self,
|
|
87
|
+
image_means: list[float],
|
|
88
|
+
image_stds: list[float],
|
|
89
|
+
target_means: list[float] | None = None,
|
|
90
|
+
target_stds: list[float] | None = None,
|
|
91
|
+
):
|
|
92
|
+
"""Constructor.
|
|
93
|
+
|
|
94
|
+
Parameters
|
|
95
|
+
----------
|
|
96
|
+
image_means : list of float
|
|
97
|
+
Mean value per channel.
|
|
98
|
+
image_stds : list of float
|
|
99
|
+
Standard deviation value per channel.
|
|
100
|
+
target_means : list of float, optional
|
|
101
|
+
Target mean value per channel, by default None.
|
|
102
|
+
target_stds : list of float, optional
|
|
103
|
+
Target standard deviation value per channel, by default None.
|
|
104
|
+
"""
|
|
105
|
+
self.image_means = image_means
|
|
106
|
+
self.image_stds = image_stds
|
|
107
|
+
self.target_means = target_means
|
|
108
|
+
self.target_stds = target_stds
|
|
109
|
+
|
|
110
|
+
self.eps = 1e-6
|
|
111
|
+
|
|
112
|
+
def __call__(
|
|
113
|
+
self,
|
|
114
|
+
patch: np.ndarray,
|
|
115
|
+
target: NDArray | None = None,
|
|
116
|
+
**additional_arrays: NDArray,
|
|
117
|
+
) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
|
|
118
|
+
"""Apply the transform to the source patch and the target (optional).
|
|
119
|
+
|
|
120
|
+
Parameters
|
|
121
|
+
----------
|
|
122
|
+
patch : NDArray
|
|
123
|
+
Patch, 2D or 3D, shape C(Z)YX.
|
|
124
|
+
target : NDArray, optional
|
|
125
|
+
Target for the patch, by default None.
|
|
126
|
+
**additional_arrays : NDArray
|
|
127
|
+
Additional arrays that will be transformed identically to `patch` and
|
|
128
|
+
`target`.
|
|
129
|
+
|
|
130
|
+
Returns
|
|
131
|
+
-------
|
|
132
|
+
tuple of NDArray
|
|
133
|
+
Transformed patch and target, the target can be returned as `None`.
|
|
134
|
+
"""
|
|
135
|
+
if len(self.image_means) != patch.shape[0]:
|
|
136
|
+
raise ValueError(
|
|
137
|
+
f"Number of means (got a list of size {len(self.image_means)}) and "
|
|
138
|
+
f"number of channels (got shape {patch.shape} for C(Z)YX) do not match."
|
|
139
|
+
)
|
|
140
|
+
if len(additional_arrays) != 0:
|
|
141
|
+
raise NotImplementedError(
|
|
142
|
+
"Transforming additional arrays is currently not supported for "
|
|
143
|
+
"`Normalize`."
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
# reshape mean and std and apply the normalization to the patch
|
|
147
|
+
means = _reshape_stats(self.image_means, patch.ndim)
|
|
148
|
+
stds = _reshape_stats(self.image_stds, patch.ndim)
|
|
149
|
+
norm_patch = self._apply(patch, means, stds)
|
|
150
|
+
|
|
151
|
+
# same for the target patch
|
|
152
|
+
if target is None:
|
|
153
|
+
norm_target = None
|
|
154
|
+
else:
|
|
155
|
+
if not self.target_means or not self.target_stds:
|
|
156
|
+
raise ValueError(
|
|
157
|
+
"Target means and standard deviations must be provided "
|
|
158
|
+
"if target is not None."
|
|
159
|
+
)
|
|
160
|
+
if len(self.target_means) == 0 and len(self.target_stds) == 0:
|
|
161
|
+
raise ValueError(
|
|
162
|
+
"Target means and standard deviations must be provided "
|
|
163
|
+
"if target is not None."
|
|
164
|
+
)
|
|
165
|
+
if len(self.target_means) != target.shape[0]:
|
|
166
|
+
raise ValueError(
|
|
167
|
+
"Target means and standard deviations must have the same length "
|
|
168
|
+
"as the target."
|
|
169
|
+
)
|
|
170
|
+
target_means = _reshape_stats(self.target_means, target.ndim)
|
|
171
|
+
target_stds = _reshape_stats(self.target_stds, target.ndim)
|
|
172
|
+
norm_target = self._apply(target, target_means, target_stds)
|
|
173
|
+
|
|
174
|
+
return norm_patch, norm_target, additional_arrays
|
|
175
|
+
|
|
176
|
+
def _apply(self, patch: NDArray, mean: NDArray, std: NDArray) -> NDArray:
    """
    Normalize an image patch with the given statistics.

    Computes ``(patch - mean) / (std + eps)`` and casts the result to float32.

    Parameters
    ----------
    patch : NDArray
        Image patch, 2D or 3D, shape C(Z)YX.
    mean : NDArray
        Mean values, broadcastable against `patch`.
    std : NDArray
        Standard deviations, broadcastable against `patch`.

    Returns
    -------
    NDArray
        Normalized image patch, as float32.
    """
    centered = patch - mean
    # eps guards against division by zero for constant channels.
    scaled = centered / (std + self.eps)
    return scaled.astype(np.float32)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class Denormalize:
    """
    Denormalize an image.

    Inverts a zero-mean / unit-variance normalization using per-channel
    statistics, computing ``x * (std + eps) + mean``. The input is expected to be
    a batch with the channel dimension at index 1, i.e. BC(Z)YX.

    Note that an epsilon value of 1e-6 is added to the standard deviation to avoid
    division by zero during the normalization step, which is taken into account
    during denormalization.

    Parameters
    ----------
    image_means : list or tuple of float
        Mean value per channel.
    image_stds : list or tuple of float
        Standard deviation value per channel.
    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
    ):
        """Store the per-channel statistics to undo.

        Parameters
        ----------
        image_means : list of float
            Mean value per channel.
        image_stds : list of float
            Standard deviation value per channel.
        """
        self.image_means, self.image_stds = image_means, image_stds
        # Must match the epsilon used at normalization time to invert it exactly.
        self.eps = 1e-6

    def __call__(self, patch: NDArray) -> NDArray:
        """Reverse the normalization operation for a batch of patches.

        Parameters
        ----------
        patch : NDArray
            Patch, 2D or 3D, shape BC(Z)YX.

        Returns
        -------
        NDArray
            Denormalized array, as float32.
        """
        # NOTE: the check of len(self.image_means) against patch.shape[1] is
        # currently disabled; presumably because PN2V outputs do not follow the
        # per-channel layout (see TODO) — confirm before re-enabling.
        # TODO for pn2v channel handling needs to be changed
        n_dims = patch.ndim
        # The reshaped stats carry the channel on axis 0; move it to axis 1 so it
        # lines up with the C axis of a BC(Z)YX batch.
        channel_means = np.swapaxes(_reshape_stats(self.image_means, n_dims), 0, 1)
        channel_stds = np.swapaxes(_reshape_stats(self.image_stds, n_dims), 0, 1)

        restored = self._apply(patch, channel_means, channel_stds)
        return restored.astype(np.float32)

    def _apply(self, array: NDArray, mean: NDArray, std: NDArray) -> NDArray:
        """
        Undo the normalization on an array.

        Parameters
        ----------
        array : NDArray
            Normalized array; shape BC(Z)YX when called from `__call__`.
        mean : NDArray
            Mean values, broadcastable against `array`.
        std : NDArray
            Standard deviations, broadcastable against `array`.

        Returns
        -------
        NDArray
            Denormalized image array (dtype follows numpy promotion rules).
        """
        scale = std + self.eps
        return array * scale + mean
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class TrainDenormalize:
    """
    Denormalize a batch of torch tensors during training.

    Tensor counterpart of `Denormalize`. The input is expected to be a BC(Z)YX
    tensor whose channel dimension is at index 1. Per-channel statistics are
    moved to the input's device and dtype before being applied as
    ``x * (std + eps) + mean``.

    Parameters
    ----------
    image_means : list or tuple of float
        Mean value per channel.
    image_stds : list or tuple of float
        Standard deviation value per channel.
    """

    def __init__(
        self,
        image_means: list[float],
        image_stds: list[float],
    ) -> None:
        """Initialize the TrainDenormalize transform.

        Parameters
        ----------
        image_means : list of float
            Mean values per channel.
        image_stds : list of float
            Standard deviation values per channel.
        """
        self.image_means, self.image_stds = image_means, image_stds
        # Same epsilon as the normalization step, so that it is inverted exactly.
        self.eps = 1e-6

    def __call__(self, patch: Tensor) -> Tensor:
        """Reverse the normalization operation for a batch of patches.

        Parameters
        ----------
        patch : Tensor
            Patch, 2D or 3D, shape BC(Z)YX.

        Returns
        -------
        Tensor
            Denormalized tensor with dtype float32.
        """
        # NOTE: the check of len(self.image_means) against patch.shape[1] is
        # currently disabled; presumably because PN2V outputs do not follow the
        # per-channel layout (see TODO) — confirm before re-enabling.
        # TODO for pn2v channel handling needs to be changed
        placement = {"device": patch.device, "dtype": patch.dtype}
        n_dims = patch.ndim
        channel_means = _reshape_stats_torch(self.image_means, n_dims).to(**placement)
        channel_stds = _reshape_stats_torch(self.image_stds, n_dims).to(**placement)

        # Move the channel from axis 0 to axis 1 to line up with BC(Z)YX.
        restored = self._apply(
            patch,
            torch.transpose(channel_means, 0, 1),
            torch.transpose(channel_stds, 0, 1),
        )
        return restored.float()

    def _apply(self, array: Tensor, mean: Tensor, std: Tensor) -> Tensor:
        """Undo the normalization on a tensor.

        Parameters
        ----------
        array : Tensor
            Normalized input tensor.
        mean : Tensor
            Mean values, broadcastable against `array`.
        std : Tensor
            Standard deviation values, broadcastable against `array`.

        Returns
        -------
        Tensor
            Denormalized tensor.
        """
        scale = std + self.eps
        return array * scale + mean
|