careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Layer module.
|
|
3
|
+
|
|
4
|
+
This submodule contains layers used in the CAREamics models.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from torch.nn import functional as F
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Conv_Block(nn.Module):
    """
    Convolution block used in UNets.

    The block stacks two convolutions with a kernel size of 3. Each convolution
    is optionally followed by batch normalization and then by the activation;
    an optional dropout is applied at the very end:

    ``conv1 -> [batch_norm1] -> activation -> conv2 -> [batch_norm2]
    -> activation -> [dropout]``

    The parameters are directly mapped to PyTorch ``torch.nn.Conv2d`` and
    ``torch.nn.Conv3d`` parameters, see their documentation for more
    information.

    Parameters
    ----------
    conv_dim : int
        Number of spatial dimensions of the convolutions, 2 or 3.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    intermediate_channel_multiplier : int, optional
        Factor applied to `out_channels` to obtain the number of channels
        produced by the first convolution, by default 1.
    stride : int, optional
        Stride of both convolutions, by default 1.
    padding : int, optional
        Padding of both convolutions, by default 1.
    bias : bool, optional
        Whether both convolutions have a bias, by default True.
    groups : int, optional
        Controls the connections between inputs and outputs, by default 1.
    activation : str, optional
        Name of a `torch.nn` activation class, instantiated without arguments;
        `None` selects `nn.Identity`. By default "ReLU".
    dropout_perc : float, optional
        Dropout probability passed to `nn.Dropout2d`/`nn.Dropout3d`; no
        dropout layer is created when it is 0, by default 0.
    use_batch_norm : bool, optional
        Whether batch normalization is applied after each convolution, by
        default False.
    """

    def __init__(
        self,
        conv_dim: int,
        in_channels: int,
        out_channels: int,
        intermediate_channel_multiplier: int = 1,
        stride: int = 1,
        padding: int = 1,
        bias: bool = True,
        groups: int = 1,
        activation: str = "ReLU",
        dropout_perc: float = 0,
        use_batch_norm: bool = False,
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        conv_dim : int
            Number of spatial dimensions of the convolutions, 2 or 3.
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        intermediate_channel_multiplier : int, optional
            Factor applied to `out_channels` to obtain the number of channels
            produced by the first convolution, by default 1.
        stride : int, optional
            Stride of both convolutions, by default 1.
        padding : int, optional
            Padding of both convolutions, by default 1.
        bias : bool, optional
            Whether both convolutions have a bias, by default True.
        groups : int, optional
            Controls the connections between inputs and outputs, by default 1.
        activation : str, optional
            Name of a `torch.nn` activation class, instantiated without
            arguments; `None` selects `nn.Identity`. By default "ReLU".
        dropout_perc : float, optional
            Dropout probability; no dropout layer is created when it is 0, by
            default 0.
        use_batch_norm : bool, optional
            Whether batch normalization is applied after each convolution, by
            default False.
        """
        super().__init__()
        self.use_batch_norm = use_batch_norm

        # nn.Conv2d or nn.Conv3d, depending on the requested dimensionality
        conv_layer = getattr(nn, f"Conv{conv_dim}d")
        hidden_channels = out_channels * intermediate_channel_multiplier
        conv_kwargs = dict(
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=bias,
            groups=groups,
        )
        self.conv1 = conv_layer(in_channels, hidden_channels, **conv_kwargs)
        self.conv2 = conv_layer(hidden_channels, out_channels, **conv_kwargs)

        # The normalization layers always exist (and are part of the state
        # dict); `forward` only uses them when `use_batch_norm` is True.
        norm_layer = getattr(nn, f"BatchNorm{conv_dim}d")
        self.batch_norm1 = norm_layer(hidden_channels)
        self.batch_norm2 = norm_layer(out_channels)

        if dropout_perc > 0:
            self.dropout = getattr(nn, f"Dropout{conv_dim}d")(dropout_perc)
        else:
            self.dropout = None

        if activation is None:
            self.activation = nn.Identity()
        else:
            self.activation = getattr(nn, f"{activation}")()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        stages = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
        )
        for conv, norm in stages:
            x = conv(x)
            if self.use_batch_norm:
                x = norm(x)
            x = self.activation(x)

        if self.dropout is not None:
            x = self.dropout(x)
        return x
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _unpack_kernel_size(
|
|
160
|
+
kernel_size: Union[tuple[int, ...], int], dim: int
|
|
161
|
+
) -> tuple[int, ...]:
|
|
162
|
+
"""Unpack kernel_size to a tuple of ints.
|
|
163
|
+
|
|
164
|
+
Inspired by Kornia implementation. TODO: link
|
|
165
|
+
|
|
166
|
+
Parameters
|
|
167
|
+
----------
|
|
168
|
+
kernel_size : Union[tuple[int, ...], int]
|
|
169
|
+
Kernel size.
|
|
170
|
+
dim : int
|
|
171
|
+
Number of dimensions.
|
|
172
|
+
|
|
173
|
+
Returns
|
|
174
|
+
-------
|
|
175
|
+
tuple[int, ...]
|
|
176
|
+
Kernel size tuple.
|
|
177
|
+
"""
|
|
178
|
+
if isinstance(kernel_size, int):
|
|
179
|
+
kernel_dims = tuple([kernel_size for _ in range(dim)])
|
|
180
|
+
else:
|
|
181
|
+
kernel_dims = kernel_size
|
|
182
|
+
return kernel_dims
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _compute_zero_padding(
    kernel_size: Union[tuple[int, ...], int], dim: int
) -> tuple[int, ...]:
    """Compute the per-dimension symmetric zero padding for a kernel.

    Each kernel extent `k` is mapped to a padding of `(k - 1) // 2`.

    Parameters
    ----------
    kernel_size : Union[tuple[int, ...], int]
        Kernel size, a single int for all dimensions or one value per dimension.
    dim : int
        Number of dimensions.

    Returns
    -------
    tuple[int, ...]
        Zero padding tuple, one entry per kernel dimension.
    """
    extents = _unpack_kernel_size(kernel_size, dim)
    return tuple((extent - 1) // 2 for extent in extents)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def get_pascal_kernel_1d(
|
|
207
|
+
kernel_size: int,
|
|
208
|
+
norm: bool = False,
|
|
209
|
+
*,
|
|
210
|
+
device: torch.device | None = None,
|
|
211
|
+
dtype: torch.dtype | None = None,
|
|
212
|
+
) -> torch.Tensor:
|
|
213
|
+
"""Generate Yang Hui triangle (Pascal's triangle) for a given number.
|
|
214
|
+
|
|
215
|
+
Inspired by Kornia implementation. TODO link
|
|
216
|
+
|
|
217
|
+
Parameters
|
|
218
|
+
----------
|
|
219
|
+
kernel_size : int
|
|
220
|
+
Kernel size.
|
|
221
|
+
norm : bool
|
|
222
|
+
Normalize the kernel, by default False.
|
|
223
|
+
device : Optional[torch.device]
|
|
224
|
+
Device of the tensor, by default None.
|
|
225
|
+
dtype : Optional[torch.dtype]
|
|
226
|
+
Data type of the tensor, by default None.
|
|
227
|
+
|
|
228
|
+
Returns
|
|
229
|
+
-------
|
|
230
|
+
torch.Tensor
|
|
231
|
+
Pascal kernel.
|
|
232
|
+
|
|
233
|
+
Examples
|
|
234
|
+
--------
|
|
235
|
+
>>> get_pascal_kernel_1d(1)
|
|
236
|
+
tensor([1.])
|
|
237
|
+
>>> get_pascal_kernel_1d(2)
|
|
238
|
+
tensor([1., 1.])
|
|
239
|
+
>>> get_pascal_kernel_1d(3)
|
|
240
|
+
tensor([1., 2., 1.])
|
|
241
|
+
>>> get_pascal_kernel_1d(4)
|
|
242
|
+
tensor([1., 3., 3., 1.])
|
|
243
|
+
>>> get_pascal_kernel_1d(5)
|
|
244
|
+
tensor([1., 4., 6., 4., 1.])
|
|
245
|
+
>>> get_pascal_kernel_1d(6)
|
|
246
|
+
tensor([ 1., 5., 10., 10., 5., 1.])
|
|
247
|
+
"""
|
|
248
|
+
pre: list[float] = []
|
|
249
|
+
cur: list[float] = []
|
|
250
|
+
for i in range(kernel_size):
|
|
251
|
+
cur = [1.0] * (i + 1)
|
|
252
|
+
|
|
253
|
+
for j in range(1, i // 2 + 1):
|
|
254
|
+
value = pre[j - 1] + pre[j]
|
|
255
|
+
cur[j] = value
|
|
256
|
+
if i != 2 * j:
|
|
257
|
+
cur[-j - 1] = value
|
|
258
|
+
pre = cur
|
|
259
|
+
|
|
260
|
+
out = torch.tensor(cur, device=device, dtype=dtype)
|
|
261
|
+
|
|
262
|
+
if norm:
|
|
263
|
+
out = out / out.sum()
|
|
264
|
+
|
|
265
|
+
return out
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _get_pascal_kernel_nd(
    kernel_size: Union[tuple[int, ...], int],
    norm: bool = True,
    dim: int = 2,
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Generate pascal filter kernel by kernel size.

    If kernel_size is an integer the kernel will be shaped as
    `(kernel_size,) * dim`, otherwise the kernel will be shaped as kernel_size.
    The kernel is the outer product of the 1D Pascal kernels of every axis.

    Inspired by Kornia implementation.

    Parameters
    ----------
    kernel_size : Union[tuple[int, ...], int]
        Kernel size for the pascal kernel.
    norm : bool
        Normalize the kernel so that it sums to 1, by default True.
    dim : int
        Number of dimensions, by default 2.
    device : Optional[torch.device]
        Device of the tensor, by default None.
    dtype : Optional[torch.dtype]
        Data type of the tensor, by default None.

    Returns
    -------
    torch.Tensor
        Pascal kernel with `dim` dimensions.

    Raises
    ------
    ValueError
        If `dim` is smaller than 1, or if `kernel_size` is a sequence that does
        not have exactly `dim` entries.

    Examples
    --------
    >>> _get_pascal_kernel_nd(1)
    tensor([[1.]])
    >>> _get_pascal_kernel_nd(4)
    tensor([[0.0156, 0.0469, 0.0469, 0.0156],
            [0.0469, 0.1406, 0.1406, 0.0469],
            [0.0469, 0.1406, 0.1406, 0.0469],
            [0.0156, 0.0469, 0.0469, 0.0156]])
    >>> _get_pascal_kernel_nd(4, norm=False)
    tensor([[1., 3., 3., 1.],
            [3., 9., 9., 3.],
            [3., 9., 9., 3.],
            [1., 3., 3., 1.]])
    """
    if dim < 1:
        raise ValueError(f"`dim` must be at least 1, got {dim}.")

    kernel_dims = tuple(_unpack_kernel_size(kernel_size, dim))
    if len(kernel_dims) != dim:
        raise ValueError(
            f"Expected {dim} kernel sizes, got {len(kernel_dims)}: {kernel_dims}."
        )

    kernels_1d = [
        get_pascal_kernel_1d(kd, device=device, dtype=dtype) for kd in kernel_dims
    ]

    # Outer product of all 1D kernels: append a trailing singleton axis and
    # broadcast against the next 1D kernel. For dim == 2 this is
    # `k0[:, None] * k1[None, :]`, for dim == 3 it adds `* k2[None, None, :]`.
    kernel = kernels_1d[0]
    for kernel_1d in kernels_1d[1:]:
        kernel = kernel.unsqueeze(-1) * kernel_1d

    if norm:
        kernel = kernel / torch.sum(kernel)
    return kernel
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def _max_blur_pool_by_kernel2d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Apply a dense max pooling followed by a strided depthwise blur in 2D.

    Inspired by Kornia implementation.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor, presumably of shape (B, C, H, W); `x.size(1)` is used as
        the number of convolution groups.
    kernel : torch.Tensor
        Blurring kernel. Because the convolution uses one group per channel of
        `x`, it is expected to have shape (C, 1, kH, kW).
    stride : int
        Stride of the blurring convolution.
    max_pool_size : int
        Kernel size of the max pooling, applied with stride 1 and no padding.
    ceil_mode : bool
        Passed to the max pooling as `ceil_mode`.

    Returns
    -------
    torch.Tensor
        Max pooled, blurred and downsampled tensor.
    """
    # local maxima (no downsampling yet)
    pooled = F.max_pool2d(
        x, kernel_size=max_pool_size, padding=0, stride=1, ceil_mode=ceil_mode
    )
    # symmetric (k - 1) // 2 zero padding for each spatial kernel extent
    blur_padding = _compute_zero_padding(tuple(kernel.shape[-2:]), dim=2)
    return F.conv2d(
        pooled,
        kernel,
        padding=blur_padding,
        stride=stride,
        groups=pooled.size(1),
    )
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _max_blur_pool_by_kernel3d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Apply a dense max pooling followed by a strided depthwise blur in 3D.

    Inspired by Kornia implementation.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor, presumably of shape (B, C, D, H, W); `x.size(1)` is used
        as the number of convolution groups.
    kernel : torch.Tensor
        Blurring kernel. Because the convolution uses one group per channel of
        `x`, it is expected to have shape (C, 1, kD, kH, kW).
    stride : int
        Stride of the blurring convolution.
    max_pool_size : int
        Kernel size of the max pooling, applied with stride 1 and no padding.
    ceil_mode : bool
        Passed to the max pooling as `ceil_mode`.

    Returns
    -------
    torch.Tensor
        Max pooled, blurred and downsampled tensor.
    """
    # local maxima (no downsampling yet)
    pooled = F.max_pool3d(
        x, kernel_size=max_pool_size, padding=0, stride=1, ceil_mode=ceil_mode
    )
    # symmetric (k - 1) // 2 zero padding for each spatial kernel extent
    blur_padding = _compute_zero_padding(tuple(kernel.shape[-3:]), dim=3)
    return F.conv3d(
        pooled,
        kernel,
        padding=blur_padding,
        stride=stride,
        groups=pooled.size(1),
    )
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class MaxBlurPool(nn.Module):
    """Compute pools and blurs and downsample a given feature map.

    Inspired by Kornia MaxBlurPool implementation. Equivalent to
    ```nn.Sequential(nn.MaxPool2d(...), BlurPool2D(...))```: a max pooling with
    stride 1 followed by a depthwise convolution with a normalized Pascal
    (binomial) kernel applied with `stride`.

    Parameters
    ----------
    dim : int
        Number of spatial dimensions, 2 or 3.
    kernel_size : Union[tuple[int, ...], int]
        Size of the Pascal blurring kernel.
    stride : int
        Stride of the blurring convolution.
    max_pool_size : int
        Kernel size of the max pooling.
    ceil_mode : bool
        Passed to the max pooling as `ceil_mode`, by default False.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: Union[tuple[int, ...], int],
        stride: int = 2,
        max_pool_size: int = 2,
        ceil_mode: bool = False,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        dim : int
            Number of spatial dimensions, 2 or 3.
        kernel_size : Union[tuple[int, ...], int]
            Size of the Pascal blurring kernel.
        stride : int, optional
            Stride of the blurring convolution, by default 2.
        max_pool_size : int, optional
            Kernel size of the max pooling, by default 2.
        ceil_mode : bool, optional
            Passed to the max pooling as `ceil_mode`, by default False.

        Raises
        ------
        ValueError
            If `dim` is neither 2 nor 3.
        """
        super().__init__()
        # `forward` only knows the 2D and 3D code paths; fail early and clearly
        # instead of silently running the 3D path for any other value.
        if dim not in (2, 3):
            raise ValueError(f"MaxBlurPool only supports dim 2 or 3, got {dim}.")
        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.max_pool_size = max_pool_size
        self.ceil_mode = ceil_mode
        kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
        # Derived from the hyper-parameters, hence not saved in the state dict.
        self.register_buffer("kernel", kernel, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the function.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; `x.size(1)` is treated as the channel dimension.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        kernel = self.kernel.to(dtype=x.dtype)
        num_channels = int(x.size(1))
        # One copy of the kernel per channel for the depthwise convolution:
        # (k, k[, k]) -> (C, 1, k, k[, k]).
        depthwise_kernel = kernel.repeat((num_channels, 1) + (1,) * self.dim)
        if self.dim == 2:
            return _max_blur_pool_by_kernel2d(
                x,
                depthwise_kernel,
                self.stride,
                self.max_pool_size,
                self.ceil_mode,
            )
        return _max_blur_pool_by_kernel3d(
            x,
            depthwise_kernel,
            self.stride,
            self.max_pool_size,
            self.ceil_mode,
        )
|