careamics-0.0.19-py3-none-any.whl

Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/models/unet.py
@@ -0,0 +1,449 @@
+ """
+ UNet model.
+
+ A UNet encoder, decoder and complete model.
+ """
+
+ from typing import Any, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ..config.support import SupportedActivation
+ from .activation import get_activation
+ from .layers import Conv_Block, MaxBlurPool
+
+
+ class UnetEncoder(nn.Module):
+     """
+     Unet encoder pathway.
+
+     Parameters
+     ----------
+     conv_dim : int
+         Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+     in_channels : int, optional
+         Number of input channels, by default 1.
+     depth : int, optional
+         Number of encoder blocks, by default 3.
+     num_channels_init : int, optional
+         Number of channels in the first encoder block, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     pool_kernel : int, optional
+         Kernel size for the max pooling layers, by default 2.
+     n2v2 : bool, optional
+         Whether to use N2V2 architecture, by default False.
+     groups : int, optional
+         Number of blocked connections from input channels to output
+         channels, by default 1.
+     """
+
+     def __init__(
+         self,
+         conv_dim: int,
+         in_channels: int = 1,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         pool_kernel: int = 2,
+         n2v2: bool = False,
+         groups: int = 1,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dim : int
+             Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+         in_channels : int, optional
+             Number of input channels, by default 1.
+         depth : int, optional
+             Number of encoder blocks, by default 3.
+         num_channels_init : int, optional
+             Number of channels in the first encoder block, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         pool_kernel : int, optional
+             Kernel size for the max pooling layers, by default 2.
+         n2v2 : bool, optional
+             Whether to use N2V2 architecture, by default False.
+         groups : int, optional
+             Number of blocked connections from input channels to output
+             channels, by default 1.
+         """
+         super().__init__()
+
+         self.pooling = (
+             getattr(nn, f"MaxPool{conv_dim}d")(kernel_size=pool_kernel)
+             if not n2v2
+             else MaxBlurPool(dim=conv_dim, kernel_size=3, max_pool_size=pool_kernel)
+         )
+
+         encoder_blocks = []
+
+         for n in range(depth):
+             out_channels = num_channels_init * (2**n) * groups
+             in_channels = in_channels if n == 0 else out_channels // 2
+             encoder_blocks.append(
+                 Conv_Block(
+                     conv_dim,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     dropout_perc=dropout,
+                     use_batch_norm=use_batch_norm,
+                     groups=groups,
+                 )
+             )
+             encoder_blocks.append(self.pooling)
+         self.encoder_blocks = nn.ModuleList(encoder_blocks)
+
+     def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor.
+
+         Returns
+         -------
+         list[torch.Tensor]
+             Output of each encoder block (skip connections) and final output of the
+             encoder.
+         """
+         encoder_features = []
+         for module in self.encoder_blocks:
+             x = module(x)
+             if isinstance(module, Conv_Block):
+                 encoder_features.append(x)
+         features = [x, *encoder_features]
+         return features
+
+
+ class UnetDecoder(nn.Module):
+     """
+     Unet decoder pathway.
+
+     Parameters
+     ----------
+     conv_dim : int
+         Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+     depth : int, optional
+         Number of decoder blocks, by default 3.
+     num_channels_init : int, optional
+         Number of channels in the first encoder block, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     n2v2 : bool, optional
+         Whether to use N2V2 architecture, by default False.
+     groups : int, optional
+         Number of blocked connections from input channels to output
+         channels, by default 1.
+     """
+
+     def __init__(
+         self,
+         conv_dim: int,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         n2v2: bool = False,
+         groups: int = 1,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dim : int
+             Number of dimensions of the convolution layers, 2 for 2D or 3 for 3D.
+         depth : int, optional
+             Number of decoder blocks, by default 3.
+         num_channels_init : int, optional
+             Number of channels in the first encoder block, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         n2v2 : bool, optional
+             Whether to use N2V2 architecture, by default False.
+         groups : int, optional
+             Number of blocked connections from input channels to output
+             channels, by default 1.
+         """
+         super().__init__()
+
+         upsampling = nn.Upsample(
+             scale_factor=2, mode="bilinear" if conv_dim == 2 else "trilinear"
+         )
+         in_channels = out_channels = num_channels_init * groups * (2 ** (depth - 1))
+
+         self.n2v2 = n2v2
+         self.groups = groups
+
+         self.bottleneck = Conv_Block(
+             conv_dim,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             intermediate_channel_multiplier=2,
+             use_batch_norm=use_batch_norm,
+             dropout_perc=dropout,
+             groups=self.groups,
+         )
+
+         decoder_blocks: list[nn.Module] = []
+         for n in range(depth):
+             decoder_blocks.append(upsampling)
+
+             in_channels = (num_channels_init * 2 ** (depth - n - 1)) * groups
+             # final decoder block has the same number of in and out features
+             out_channels = in_channels // 2 if n != depth - 1 else in_channels
+             if not (n2v2 and (n == depth - 1)):
+                 in_channels = in_channels * 2  # accounting for skip connection concat
+
+             decoder_blocks.append(
+                 Conv_Block(
+                     conv_dim,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     # TODO: Tensorflow n2v implementation has intermediate channel
+                     # multiplication for skip_skipone=True but not skip_skipone=False;
+                     # this needs to be benchmarked.
+                     # final decoder block doesn't multiply the intermediate features
+                     intermediate_channel_multiplier=2 if n != depth - 1 else 1,
+                     dropout_perc=dropout,
+                     activation="ReLU",
+                     use_batch_norm=use_batch_norm,
+                     groups=groups,
+                 )
+             )
+
+         self.decoder_blocks = nn.ModuleList(decoder_blocks)
+
+     def forward(self, *features: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         *features : torch.Tensor
+             Outputs of each encoder block (skip connections) and the final output
+             of the encoder.
+
+         Returns
+         -------
+         torch.Tensor
+             Output of the decoder.
+         """
+         x: torch.Tensor = features[0]
+         skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
+         depth = len(skip_connections)
+
+         x = self.bottleneck(x)
+
+         for i, module in enumerate(self.decoder_blocks):
+             x = module(x)
+             if isinstance(module, nn.Upsample):
+                 # divide index by 2 because of upsampling layers
+                 skip_connection: torch.Tensor = skip_connections[i // 2]
+                 # top level skip connection not added for n2v2
+                 if (not self.n2v2) or (self.n2v2 and (i // 2 < depth - 1)):
+                     x = self._interleave(x, skip_connection, self.groups)
+         return x
+
+     @staticmethod
+     def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
+         """Interleave two tensors.
+
+         Splits the tensors `A` and `B` into equally sized groups along the channel
+         axis (axis=1); then concatenates the groups in alternating order along the
+         channel axis, starting with the first group from tensor A.
+
+         Parameters
+         ----------
+         A : torch.Tensor
+             First tensor.
+         B : torch.Tensor
+             Second tensor.
+         groups : int
+             The number of groups.
+
+         Returns
+         -------
+         torch.Tensor
+             Interleaved tensor.
+
+         Raises
+         ------
+         ValueError
+             If either of `A` or `B`'s channel axis is not divisible by `groups`.
+         """
+         if (A.shape[1] % groups != 0) or (B.shape[1] % groups != 0):
+             raise ValueError(f"Number of channels not divisible by {groups} groups.")
+
+         m = A.shape[1] // groups
+         n = B.shape[1] // groups
+
+         A_groups: list[torch.Tensor] = [
+             A[:, i * m : (i + 1) * m] for i in range(groups)
+         ]
+         B_groups: list[torch.Tensor] = [
+             B[:, i * n : (i + 1) * n] for i in range(groups)
+         ]
+
+         interleaved = torch.cat(
+             [
+                 tensor_list[i]
+                 for i in range(groups)
+                 for tensor_list in [A_groups, B_groups]
+             ],
+             dim=1,
+         )
+
+         return interleaved
+
+
+ class UNet(nn.Module):
+     """
+     UNet model.
+
+     Adapted for PyTorch from:
+     https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.
+
+     Parameters
+     ----------
+     conv_dims : int
+         Number of dimensions of the convolution layers (2 or 3).
+     num_classes : int, optional
+         Number of classes to predict, by default 1.
+     in_channels : int, optional
+         Number of input channels, by default 1.
+     depth : int, optional
+         Number of downsamplings, by default 3.
+     num_channels_init : int, optional
+         Number of filters in the first convolution layer, by default 64.
+     use_batch_norm : bool, optional
+         Whether to use batch normalization, by default True.
+     dropout : float, optional
+         Dropout probability, by default 0.0.
+     pool_kernel : int, optional
+         Kernel size of the pooling layers, by default 2.
+     final_activation : Union[SupportedActivation, str], optional
+         Activation function to use for the last layer, by default
+         SupportedActivation.NONE.
+     n2v2 : bool, optional
+         Whether to use N2V2 architecture, by default False.
+     independent_channels : bool, optional
+         Whether to train the channels independently, by default True.
+     **kwargs : Any
+         Additional keyword arguments, unused.
+     """
+
+     def __init__(
+         self,
+         conv_dims: int,
+         num_classes: int = 1,
+         in_channels: int = 1,
+         depth: int = 3,
+         num_channels_init: int = 64,
+         use_batch_norm: bool = True,
+         dropout: float = 0.0,
+         pool_kernel: int = 2,
+         final_activation: Union[SupportedActivation, str] = SupportedActivation.NONE,
+         n2v2: bool = False,
+         independent_channels: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         conv_dims : int
+             Number of dimensions of the convolution layers (2 or 3).
+         num_classes : int, optional
+             Number of classes to predict, by default 1.
+         in_channels : int, optional
+             Number of input channels, by default 1.
+         depth : int, optional
+             Number of downsamplings, by default 3.
+         num_channels_init : int, optional
+             Number of filters in the first convolution layer, by default 64.
+         use_batch_norm : bool, optional
+             Whether to use batch normalization, by default True.
+         dropout : float, optional
+             Dropout probability, by default 0.0.
+         pool_kernel : int, optional
+             Kernel size of the pooling layers, by default 2.
+         final_activation : Union[SupportedActivation, str], optional
+             Activation function to use for the last layer, by default
+             SupportedActivation.NONE.
+         n2v2 : bool, optional
+             Whether to use N2V2 architecture, by default False.
+         independent_channels : bool, optional
+             Whether to train parallel independent networks for each channel, by
+             default True.
+         **kwargs : Any
+             Additional keyword arguments, unused.
+         """
+         super().__init__()
+
+         groups = in_channels if independent_channels else 1
+
+         self.encoder = UnetEncoder(
+             conv_dims,
+             in_channels=in_channels,
+             depth=depth,
+             num_channels_init=num_channels_init,
+             use_batch_norm=use_batch_norm,
+             dropout=dropout,
+             pool_kernel=pool_kernel,
+             n2v2=n2v2,
+             groups=groups,
+         )
+
+         self.decoder = UnetDecoder(
+             conv_dims,
+             depth=depth,
+             num_channels_init=num_channels_init,
+             use_batch_norm=use_batch_norm,
+             dropout=dropout,
+             n2v2=n2v2,
+             groups=groups,
+         )
+         self.final_conv = getattr(nn, f"Conv{conv_dims}d")(
+             in_channels=num_channels_init * groups,
+             out_channels=num_classes,
+             kernel_size=1,
+             groups=groups,
+         )
+         self.final_activation = get_activation(final_activation)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor.
+
+         Returns
+         -------
+         torch.Tensor
+             Output of the model.
+         """
+         encoder_features = self.encoder(x)
+         x = self.decoder(*encoder_features)
+         x = self.final_conv(x)
+         x = self.final_activation(x)
+         return x
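
For orientation (not part of the diff), a minimal usage sketch of the UNet added above, assuming the usual shape-preserving Conv_Block padding; the hyperparameters are illustrative:

import torch

from careamics.models.unet import UNet

# 2D N2V2-style UNet. With independent_channels=True each input channel is
# processed by its own grouped sub-network, so num_classes must be divisible
# by in_channels.
model = UNet(
    conv_dims=2,
    num_classes=2,
    in_channels=2,
    depth=3,
    num_channels_init=32,
    n2v2=True,
    independent_channels=True,
)

x = torch.rand(1, 2, 64, 64)  # (batch, channels, height, width)
y = model(x)                  # (1, 2, 64, 64): spatial size is preserved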
careamics/nm_training_placeholder.py
@@ -0,0 +1,203 @@
+ """Placeholder code snippets for noise model training integration.
+
+ This module contains template/placeholder code that demonstrates how noise model
+ training could be integrated into CAREamist. These are reference implementations
+ and should not be imported or used directly.
+ """
+
+ import logging
+ from pathlib import Path
+ from typing import Union
+
+ from numpy.typing import NDArray
+ from pytorch_lightning.callbacks import Callback
+
+ from careamics.config.configuration import Configuration
+ from careamics.models.lvae.noise_models import (
+     GaussianMixtureNoiseModel,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ # In src/careamics/careamist.py (newly added section only)
+ def __init__(
+     self,
+     source: Union[Path, str, Configuration],
+     work_dir: Union[Path, str] | None = None,
+     callbacks: list[Callback] | None = None,
+     enable_progress_bar: bool = True,
+ ) -> None:
+     """Placeholder __init__ method showing noise model initialization.
+
+     Parameters
+     ----------
+     self : object
+         CAREamist instance.
+     source : Union[Path, str, Configuration]
+         Configuration source.
+     work_dir : Union[Path, str] | None, optional
+         Working directory, by default None.
+     callbacks : list[Callback] | None, optional
+         List of callbacks, by default None.
+     enable_progress_bar : bool, optional
+         Whether to show progress bar, by default True.
+     """
+     # ... existing initialization code ...
+
+     # Initialize untrained noise models if needed
+     self.untrained_noise_models = None
+     if (
+         hasattr(self.cfg.algorithm_config, "train_noise_model_from_data")
+         and self.cfg.algorithm_config.train_noise_model_from_data
+     ):
+         self._initialize_noise_models_for_training()
+
+
+ # In src/careamics/careamist.py
+ def train_noise_model(
+     self,
+     clean_data: Union[Path, str, NDArray],
+     noisy_data: Union[Path, str, NDArray],
+     learning_rate: float = 1e-1,
+     batch_size: int = 250000,
+     n_epochs: int = 2000,
+     lower_clip: float = 0.0,
+     upper_clip: float = 100.0,
+     save_noise_models: bool = True,
+ ) -> None:
+     """Train noise models from clean/noisy data pairs.
+
+     Parameters
+     ----------
+     self : object
+         CAREamist instance.
+     clean_data : Union[Path, str, NDArray]
+         Clean (signal) data for training noise models.
+     noisy_data : Union[Path, str, NDArray]
+         Noisy (observation) data for training noise models.
+     learning_rate : float, default=1e-1
+         Learning rate for noise model training.
+     batch_size : int, default=250000
+         Batch size for noise model training.
+     n_epochs : int, default=2000
+         Number of epochs for noise model training.
+     lower_clip : float, default=0.0
+         Lower percentile for clipping training data.
+     upper_clip : float, default=100.0
+         Upper percentile for clipping training data.
+     save_noise_models : bool, default=True
+         Whether to save trained noise models to disk.
+
+     Raises
+     ------
+     ValueError
+         If noise models are not initialized for training.
+     ValueError
+         If data shapes don't match expectations.
+     """
+     # Check if noise model is initialized (config should have MultiChannelNMConfig)
+     if self.cfg.algorithm_config.noise_model is None:
+         raise ValueError(
+             "No untrained noise models found. Set "
+             "`train_noise_model_from_data=True` in the configuration."
+         )
+
+     # Load data if paths provided (currently NM expects only numpy)
+     if isinstance(clean_data, (str, Path)):
+         clean_data = self._load_data(clean_data)
+     if isinstance(noisy_data, (str, Path)):
+         noisy_data = self._load_data(noisy_data)
+
+     # Type narrowing for mypy
+     assert not isinstance(clean_data, (str, Path))
+     assert not isinstance(noisy_data, (str, Path))
+
+     # Validate data shapes
+     if clean_data.shape != noisy_data.shape:
+         raise ValueError(
+             f"Clean and noisy data shapes must match. "
+             f"Got clean: {clean_data.shape}, noisy: {noisy_data.shape}"
+         )
+     # TODO: other data shape checks
+
+     # parameter controlling the number of channels to split for MS; for HDN it's 1
+     output_channels = self.cfg.algorithm_config.model.output_channels
+
+     # Train noise model for each channel
+     trained_noise_models = []
+     for channel_idx in range(output_channels):
+         logger.info(
+             f"Training noise model for channel {channel_idx + 1}/{output_channels}"
+         )
+
+         # Extract single channel data
+         clean_channel = clean_data[:, channel_idx]  # (N, H, W)
+         noisy_channel = noisy_data[:, channel_idx]  # (N, H, W)
+
+         # Train noise model for this channel
+         noise_model = self.untrained_noise_models[channel_idx]
+         noise_model.fit(
+             signal=clean_channel,
+             observation=noisy_channel,
+             learning_rate=learning_rate,
+             batch_size=batch_size,
+             n_epochs=n_epochs,
+             lower_clip=lower_clip,
+             upper_clip=upper_clip,
+         )
+
+         trained_noise_models.append(noise_model)
+
+         # Save individual noise model if requested
+         if save_noise_models:
+             save_path = self.work_dir / "noise_models"
+             noise_model.save(str(save_path), f"noise_model_ch{channel_idx}.npz")
+             logger.info(
+                 f"Saved noise model for channel {channel_idx} to {save_path}"
+             )
+
+     # Update the algorithm configuration with trained noise models
+     self._update_config_with_trained_noise_models(trained_noise_models)
+
+     logger.info("Noise model training completed successfully")
+
+
+ def _update_config_with_trained_noise_models(
+     self, trained_models: list[GaussianMixtureNoiseModel]
+ ) -> None:
+     """Update algorithm config with trained noise models.
+
+     Parameters
+     ----------
+     self : object
+         CAREamist instance.
+     trained_models : list[GaussianMixtureNoiseModel]
+         List of trained noise models, one per channel.
+     """
+     # Currently the model is initialized in the __init__ of CAREamist, and
+     # multichannel_noise_model_factory inside VAEModule expects paths to noise
+     # models. Ideally, we change that, call multichannel_noise_model_factory here
+     # after the model init, and update the parameters of the noise models right
+     # in the MultiChannelNoiseModel.
+
+
+ def _load_data(self, data_path: Union[Path, str]) -> NDArray:
+     """Load data from file path.
+
+     Parameters
+     ----------
+     self : object
+         CAREamist instance.
+     data_path : Union[Path, str]
+         Path to data file.
+
+     Returns
+     -------
+     NDArray
+         Loaded data array.
+
+     Raises
+     ------
+     NotImplementedError
+         This is a placeholder method.
+     """
+     raise NotImplementedError("Data loading not yet implemented")
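
Assuming this placeholder is eventually wired into CAREamist as described, the intended call pattern would look roughly like the sketch below; the config file and the paired data arrays are hypothetical stand-ins:

import numpy as np

from careamics import CAREamist

# Hypothetical paired data with shape (N, C, H, W); both stacks must match.
clean_stack = np.load("clean.npy")
noisy_stack = np.load("noisy.npy")

# Hypothetical configuration enabling train_noise_model_from_data.
careamist = CAREamist(source="hdn_config.yml")
careamist.train_noise_model(
    clean_data=clean_stack,
    noisy_data=noisy_stack,
    n_epochs=2000,
    save_noise_models=True,  # one noise_model_ch{i}.npz per channel under work_dir
)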
careamics/prediction_utils/__init__.py
@@ -0,0 +1,21 @@
+ """Package to house various prediction utilities."""
+
+ __all__ = [
+     "convert_outputs",
+     "convert_outputs_microsplit",
+     "convert_outputs_pn2v",
+     "stitch_prediction",
+     "stitch_prediction_single",
+     "stitch_prediction_vae",
+ ]
+
+ from .prediction_outputs import (
+     convert_outputs,
+     convert_outputs_microsplit,
+     convert_outputs_pn2v,
+ )
+ from .stitch_prediction import (
+     stitch_prediction,
+     stitch_prediction_single,
+     stitch_prediction_vae,
+ )
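
The __all__ re-exports above keep the public surface flat, so downstream code can import the helpers without knowing the submodule layout:

# Flat imports resolved via careamics/prediction_utils/__init__.py:
from careamics.prediction_utils import (
    convert_outputs,
    stitch_prediction,
    stitch_prediction_single,
)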