careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,495 @@
1
+ """
2
+ Layer module.
3
+
4
+ This submodule contains layers used in the CAREamics models.
5
+ """
6
+
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+
13
+
14
class Conv_Block(nn.Module):
    """
    Convolution block used in UNets.

    The block chains two 3x3 convolutions. Each convolution is followed by an
    optional batch normalization and by the activation function. When
    `dropout_perc` is strictly positive, channel dropout is applied to the
    output of the block.

    The parameters are forwarded to PyTorch `torch.nn.Conv2d` /
    `torch.nn.Conv3d`, see their documentation for more information.

    Parameters
    ----------
    conv_dim : int
        Number of dimensions of the convolutions, 2 or 3.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    intermediate_channel_multiplier : int, optional
        Multiplier applied to `out_channels` to obtain the number of channels
        between the two convolutions, by default 1.
    stride : int, optional
        Stride of both convolutions, by default 1.
    padding : int, optional
        Padding of both convolutions, by default 1.
    bias : bool, optional
        Whether both convolutions have a bias, by default True.
    groups : int, optional
        Controls the connections between inputs and outputs, by default 1.
    activation : str, optional
        Name of a `torch.nn` activation class, by default "ReLU". If None, no
        activation is applied.
    dropout_perc : float, optional
        Dropout probability, by default 0 (no dropout layer).
    use_batch_norm : bool, optional
        Whether to apply batch normalization after each convolution, by
        default False.
    """

    def __init__(
        self,
        conv_dim: int,
        in_channels: int,
        out_channels: int,
        intermediate_channel_multiplier: int = 1,
        stride: int = 1,
        padding: int = 1,
        bias: bool = True,
        groups: int = 1,
        activation: str = "ReLU",
        dropout_perc: float = 0,
        use_batch_norm: bool = False,
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        conv_dim : int
            Number of dimensions of the convolutions, 2 or 3.
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        intermediate_channel_multiplier : int, optional
            Multiplier applied to `out_channels` to obtain the number of
            channels between the two convolutions, by default 1.
        stride : int, optional
            Stride of both convolutions, by default 1.
        padding : int, optional
            Padding of both convolutions, by default 1.
        bias : bool, optional
            Whether both convolutions have a bias, by default True.
        groups : int, optional
            Controls the connections between inputs and outputs, by default 1.
        activation : str, optional
            Name of a `torch.nn` activation class, by default "ReLU". If None,
            no activation is applied.
        dropout_perc : float, optional
            Dropout probability, by default 0 (no dropout layer).
        use_batch_norm : bool, optional
            Whether to apply batch normalization after each convolution, by
            default False.
        """
        super().__init__()
        self.use_batch_norm = use_batch_norm

        conv_cls = getattr(nn, f"Conv{conv_dim}d")
        norm_cls = getattr(nn, f"BatchNorm{conv_dim}d")
        mid_channels = out_channels * intermediate_channel_multiplier
        shared_conv_kwargs = {
            "kernel_size": 3,
            "stride": stride,
            "padding": padding,
            "bias": bias,
            "groups": groups,
        }

        # Submodule creation order fixes both the random weight initialisation
        # sequence and the key order of the state dict, so keep it stable.
        self.conv1 = conv_cls(in_channels, mid_channels, **shared_conv_kwargs)
        self.conv2 = conv_cls(mid_channels, out_channels, **shared_conv_kwargs)

        # Batch-norm layers are always instantiated, even when
        # `use_batch_norm` is False (presumably to keep checkpoint keys
        # identical across configurations).
        self.batch_norm1 = norm_cls(mid_channels)
        self.batch_norm2 = norm_cls(out_channels)

        if dropout_perc > 0:
            self.dropout = getattr(nn, f"Dropout{conv_dim}d")(dropout_perc)
        else:
            self.dropout = None

        if activation is None:
            self.activation = nn.Identity()
        else:
            self.activation = getattr(nn, f"{activation}")()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the convolution block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        stages = ((self.conv1, self.batch_norm1), (self.conv2, self.batch_norm2))
        for conv, norm in stages:
            x = conv(x)
            if self.use_batch_norm:
                x = norm(x)
            x = self.activation(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x
157
+
158
+
159
+ def _unpack_kernel_size(
160
+ kernel_size: Union[tuple[int, ...], int], dim: int
161
+ ) -> tuple[int, ...]:
162
+ """Unpack kernel_size to a tuple of ints.
163
+
164
+ Inspired by Kornia implementation. TODO: link
165
+
166
+ Parameters
167
+ ----------
168
+ kernel_size : Union[tuple[int, ...], int]
169
+ Kernel size.
170
+ dim : int
171
+ Number of dimensions.
172
+
173
+ Returns
174
+ -------
175
+ tuple[int, ...]
176
+ Kernel size tuple.
177
+ """
178
+ if isinstance(kernel_size, int):
179
+ kernel_dims = tuple([kernel_size for _ in range(dim)])
180
+ else:
181
+ kernel_dims = kernel_size
182
+ return kernel_dims
183
+
184
+
185
def _compute_zero_padding(
    kernel_size: Union[tuple[int, ...], int], dim: int
) -> tuple[int, ...]:
    """Compute the zero padding of each dimension for a given kernel size.

    Each dimension gets ``(k - 1) // 2`` padding, where ``k`` is the kernel
    extent along that dimension.

    Parameters
    ----------
    kernel_size : Union[tuple[int, ...], int]
        Kernel size, a single int for every dimension or one int per
        dimension.
    dim : int
        Number of dimensions.

    Returns
    -------
    tuple[int, ...]
        Zero padding tuple, one entry per kernel dimension.
    """
    return tuple(
        (extent - 1) // 2 for extent in _unpack_kernel_size(kernel_size, dim)
    )
204
+
205
+
206
+ def get_pascal_kernel_1d(
207
+ kernel_size: int,
208
+ norm: bool = False,
209
+ *,
210
+ device: torch.device | None = None,
211
+ dtype: torch.dtype | None = None,
212
+ ) -> torch.Tensor:
213
+ """Generate Yang Hui triangle (Pascal's triangle) for a given number.
214
+
215
+ Inspired by Kornia implementation. TODO link
216
+
217
+ Parameters
218
+ ----------
219
+ kernel_size : int
220
+ Kernel size.
221
+ norm : bool
222
+ Normalize the kernel, by default False.
223
+ device : Optional[torch.device]
224
+ Device of the tensor, by default None.
225
+ dtype : Optional[torch.dtype]
226
+ Data type of the tensor, by default None.
227
+
228
+ Returns
229
+ -------
230
+ torch.Tensor
231
+ Pascal kernel.
232
+
233
+ Examples
234
+ --------
235
+ >>> get_pascal_kernel_1d(1)
236
+ tensor([1.])
237
+ >>> get_pascal_kernel_1d(2)
238
+ tensor([1., 1.])
239
+ >>> get_pascal_kernel_1d(3)
240
+ tensor([1., 2., 1.])
241
+ >>> get_pascal_kernel_1d(4)
242
+ tensor([1., 3., 3., 1.])
243
+ >>> get_pascal_kernel_1d(5)
244
+ tensor([1., 4., 6., 4., 1.])
245
+ >>> get_pascal_kernel_1d(6)
246
+ tensor([ 1., 5., 10., 10., 5., 1.])
247
+ """
248
+ pre: list[float] = []
249
+ cur: list[float] = []
250
+ for i in range(kernel_size):
251
+ cur = [1.0] * (i + 1)
252
+
253
+ for j in range(1, i // 2 + 1):
254
+ value = pre[j - 1] + pre[j]
255
+ cur[j] = value
256
+ if i != 2 * j:
257
+ cur[-j - 1] = value
258
+ pre = cur
259
+
260
+ out = torch.tensor(cur, device=device, dtype=dtype)
261
+
262
+ if norm:
263
+ out = out / out.sum()
264
+
265
+ return out
266
+
267
+
268
def _get_pascal_kernel_nd(
    kernel_size: Union[tuple[int, ...], int],
    norm: bool = True,
    dim: int = 2,
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Generate pascal filter kernel by kernel size.

    The N-dimensional kernel is the outer product of one 1D Pascal kernel per
    dimension. If `kernel_size` is an integer the kernel has shape
    ``(kernel_size,) * dim``, otherwise its shape is `kernel_size`.

    Inspired by Kornia implementation.

    Parameters
    ----------
    kernel_size : Union[tuple[int, ...], int]
        Kernel size for the pascal kernel, a single int or one int per
        dimension.
    norm : bool
        Normalize the kernel so that it sums to one, by default True.
    dim : int
        Number of dimensions, by default 2. Any positive value is supported.
    device : Optional[torch.device]
        Device of the tensor, by default None.
    dtype : Optional[torch.dtype]
        Data type of the tensor, by default None.

    Returns
    -------
    torch.Tensor
        Pascal kernel with `dim` dimensions.

    Raises
    ------
    ValueError
        If `dim` is smaller than 1, or if `kernel_size` is a sequence whose
        length differs from `dim`.

    Examples
    --------
    >>> _get_pascal_kernel_nd(1)
    tensor([[1.]])
    >>> _get_pascal_kernel_nd(4)
    tensor([[0.0156, 0.0469, 0.0469, 0.0156],
            [0.0469, 0.1406, 0.1406, 0.0469],
            [0.0469, 0.1406, 0.1406, 0.0469],
            [0.0156, 0.0469, 0.0469, 0.0156]])
    >>> _get_pascal_kernel_nd(4, norm=False)
    tensor([[1., 3., 3., 1.],
            [3., 9., 9., 3.],
            [3., 9., 9., 3.],
            [1., 3., 3., 1.]])
    """
    if dim < 1:
        raise ValueError(f"`dim` must be a positive integer, got {dim}.")

    kernel_dims = _unpack_kernel_size(kernel_size, dim)
    if len(kernel_dims) != dim:
        raise ValueError(
            f"Expected {dim} kernel sizes, got {len(kernel_dims)}: {kernel_size}."
        )

    kernels_1d = [
        get_pascal_kernel_1d(kd, device=device, dtype=dtype) for kd in kernel_dims
    ]

    # Outer product built one axis at a time: a trailing singleton axis lets
    # the next 1D kernel broadcast along a new last dimension. For 2D and 3D
    # this multiplies the factors in the same left-to-right order as explicit
    # broadcasting (k0 * k1 * k2).
    kernel = kernels_1d[0]
    for kernel_1d in kernels_1d[1:]:
        kernel = kernel[..., None] * kernel_1d

    if norm:
        kernel = kernel / torch.sum(kernel)
    return kernel
333
+
334
+
335
def _max_blur_pool_by_kernel2d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Max-pool `x` at stride 1, then blur and downsample it with `kernel`.

    Inspired by Kornia implementation.

    The blur is a grouped convolution with one group per channel (dimension 1)
    of `x`. `kernel` is therefore expected to be a depthwise weight of shape
    ``(C, 1, kH, kW)``, which is the layout `MaxBlurPool` builds.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; dimension 1 is treated as the channel axis.
    kernel : torch.Tensor
        Depthwise blur kernel.
    stride : int
        Stride of the blurring convolution (downsampling factor).
    max_pool_size : int
        Window size of the stride-1 max pooling.
    ceil_mode : bool
        Ceil mode of the max pooling.

    Returns
    -------
    torch.Tensor
        Output tensor.
    """
    # Local maxima over a `max_pool_size` window, without downsampling.
    pooled = F.max_pool2d(
        x, kernel_size=max_pool_size, stride=1, padding=0, ceil_mode=ceil_mode
    )
    # Depthwise blur; the convolution stride performs the downsampling.
    blur_padding = _compute_zero_padding(tuple(kernel.shape[-2:]), dim=2)
    return F.conv2d(
        pooled,
        kernel,
        stride=stride,
        padding=blur_padding,
        groups=pooled.size(1),
    )
371
+
372
+
373
def _max_blur_pool_by_kernel3d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    stride: int,
    max_pool_size: int,
    ceil_mode: bool,
) -> torch.Tensor:
    """Max-pool a volume at stride 1, then blur and downsample it with `kernel`.

    Inspired by Kornia implementation.

    The blur is a grouped convolution with one group per channel (dimension 1)
    of `x`. `kernel` is therefore expected to be a depthwise weight of shape
    ``(C, 1, kD, kH, kW)``, which is the layout `MaxBlurPool` builds.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; dimension 1 is treated as the channel axis.
    kernel : torch.Tensor
        Depthwise blur kernel.
    stride : int
        Stride of the blurring convolution (downsampling factor).
    max_pool_size : int
        Window size of the stride-1 max pooling.
    ceil_mode : bool
        Ceil mode of the 3D max pooling.

    Returns
    -------
    torch.Tensor
        Output tensor.
    """
    # Local maxima over a `max_pool_size` window, without downsampling.
    pooled = F.max_pool3d(
        x, kernel_size=max_pool_size, stride=1, padding=0, ceil_mode=ceil_mode
    )
    # Depthwise blur; the convolution stride performs the downsampling.
    blur_padding = _compute_zero_padding(tuple(kernel.shape[-3:]), dim=3)
    return F.conv3d(
        pooled,
        kernel,
        stride=stride,
        padding=blur_padding,
        groups=pooled.size(1),
    )
411
+
412
+
413
class MaxBlurPool(nn.Module):
    """Compute pools and blurs and downsample a given feature map.

    Inspired by Kornia MaxBlurPool implementation. Equivalent to
    ```nn.Sequential(nn.MaxPool2d(...), BlurPool2D(...))```

    The input is first max-pooled with a `max_pool_size` window at stride 1.
    It is then blurred, channel by channel, with a normalized Pascal
    (binomial) kernel of size `kernel_size`. The blurring convolution uses
    `stride` to downsample.

    Parameters
    ----------
    dim : int
        Toggles between 2D and 3D; must be 2 or 3.
    kernel_size : Union[tuple[int, ...], int]
        Size of the Pascal blur kernel (one int, or one int per dimension).
    stride : int
        Stride of the blurring convolution, i.e. the downsampling factor.
    max_pool_size : int
        Window size of the stride-1 max pooling.
    ceil_mode : bool
        Ceil mode of the max pooling, by default False.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: Union[tuple[int, ...], int],
        stride: int = 2,
        max_pool_size: int = 2,
        ceil_mode: bool = False,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        dim : int
            Dimension of the convolution, 2 or 3.
        kernel_size : Union[tuple[int, ...], int]
            Size of the Pascal blur kernel.
        stride : int, optional
            Stride of the blurring convolution, by default 2.
        max_pool_size : int, optional
            Window size of the stride-1 max pooling, by default 2.
        ceil_mode : bool, optional
            Ceil mode of the max pooling, by default False.

        Raises
        ------
        ValueError
            If `dim` is not 2 or 3.
        """
        super().__init__()
        # Only 2D and 3D pooling paths exist; fail early rather than silently
        # routing other values to the 3D implementation in `forward`.
        if dim not in (2, 3):
            raise ValueError(f"`dim` must be 2 or 3, got {dim}.")
        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.max_pool_size = max_pool_size
        self.ceil_mode = ceil_mode
        kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
        # Non-persistent buffer: it follows the module across devices but is
        # not written to the state dict (it is rebuilt from `kernel_size`).
        self.register_buffer("kernel", kernel, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the function.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; dimension 1 is treated as the channel axis.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        kernel = self.kernel.to(dtype=x.dtype)
        num_channels = int(x.size(1))
        # One copy of the blur kernel per channel for the depthwise
        # convolution: shape (C, 1, *kernel.shape).
        depthwise_kernel = kernel.repeat((num_channels, 1) + (1,) * self.dim)
        if self.dim == 2:
            return _max_blur_pool_by_kernel2d(
                x,
                depthwise_kernel,
                self.stride,
                self.max_pool_size,
                self.ceil_mode,
            )
        return _max_blur_pool_by_kernel3d(
            x,
            depthwise_kernel,
            self.stride,
            self.max_pool_size,
            self.ceil_mode,
        )
@@ -0,0 +1,3 @@
1
+ __all__ = ["LadderVAE"]
2
+
3
+ from .lvae import LadderVAE