careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,20 @@
1
+ """Patch filtering strategies."""
2
+
3
+ __all__ = [
4
+ "CoordinateFilterProtocol",
5
+ "MaskCoordFilter",
6
+ "MaxPatchFilter",
7
+ "MeanStdPatchFilter",
8
+ "PatchFilterProtocol",
9
+ "ShannonPatchFilter",
10
+ "create_coord_filter",
11
+ "create_patch_filter",
12
+ ]
13
+
14
+ from .coordinate_filter_protocol import CoordinateFilterProtocol
15
+ from .filter_factory import create_coord_filter, create_patch_filter
16
+ from .mask_filter import MaskCoordFilter
17
+ from .max_filter import MaxPatchFilter
18
+ from .mean_std_filter import MeanStdPatchFilter
19
+ from .patch_filter_protocol import PatchFilterProtocol
20
+ from .shannon_filter import ShannonPatchFilter
@@ -0,0 +1,27 @@
1
+ """A protocol for patch filtering."""
2
+
3
+ from typing import Protocol
4
+
5
+ from careamics.dataset_ng.patching_strategies import PatchSpecs
6
+
7
+
8
class CoordinateFilterProtocol(Protocol):
    """
    Structural interface for filters that judge a patch by its coordinates.

    Any object exposing a compatible `filter_out` method satisfies this
    protocol; no inheritance is required.
    """

    def filter_out(self, patch: PatchSpecs) -> bool:
        """
        Decide whether the patch described by `patch` should be excluded.

        Parameters
        ----------
        patch : PatchSpecs
            Specification (coordinates) of the patch to evaluate.

        Returns
        -------
        bool
            ``True`` to exclude the patch, ``False`` to keep it.
        """
        ...
@@ -0,0 +1,95 @@
1
+ """Factories for coordinate and patch filters."""
2
+
3
+ from typing import Union
4
+
5
+ from careamics.config.data.patch_filter import (
6
+ FilterConfig,
7
+ MaskFilterConfig,
8
+ MaxFilterConfig,
9
+ MeanSTDFilterConfig,
10
+ ShannonFilterConfig,
11
+ )
12
+ from careamics.config.support.supported_filters import (
13
+ SupportedCoordinateFilters,
14
+ SupportedPatchFilters,
15
+ )
16
+ from careamics.dataset_ng.image_stack import GenericImageStack
17
+ from careamics.dataset_ng.patch_extractor import PatchExtractor
18
+
19
+ from .mask_filter import MaskCoordFilter
20
+ from .max_filter import MaxPatchFilter
21
+ from .mean_std_filter import MeanStdPatchFilter
22
+ from .shannon_filter import ShannonPatchFilter
23
+
24
# Every patch-level filter class that `create_patch_filter` may return.
PatchFilter = Union[MaxPatchFilter, MeanStdPatchFilter, ShannonPatchFilter]

# Every coordinate-level filter class that `create_coord_filter` may return.
# NOTE: a single-member `Union` collapses to that member at runtime; the
# `Union` spelling is kept so further coordinate filters can be appended.
CoordFilter = Union[MaskCoordFilter]
32
+
33
+
34
def create_coord_filter(
    filter_model: FilterConfig, mask: PatchExtractor[GenericImageStack]
) -> CoordFilter:
    """Factory function to create coordinate filter instances based on the filter name.

    Parameters
    ----------
    filter_model : FilterConfig
        Pydantic model of the filter to be created. Its `name` attribute selects
        which filter class is instantiated.
    mask : PatchExtractor[GenericImageStack]
        Mask extractor to be used for the mask filter.

    Returns
    -------
    CoordFilter
        Instance of the mask coordinate filter.

    Raises
    ------
    ValueError
        If `filter_model.name` is not a supported coordinate filter.
    """
    if filter_model.name == SupportedCoordinateFilters.MASK:
        # Narrow the generic config type to the mask-specific model.
        assert isinstance(filter_model, MaskFilterConfig)
        return MaskCoordFilter(
            mask_extractor=mask,
            coverage=filter_model.coverage,
            p=filter_model.p,
            seed=filter_model.seed,
        )

    # Report the offending name rather than dumping the whole config model.
    raise ValueError(f"Unknown filter name: {filter_model.name}")
61
+
62
+
63
def create_patch_filter(filter_model: FilterConfig) -> PatchFilter:
    """Factory function to create patch filter instances based on the filter name.

    Parameters
    ----------
    filter_model : FilterConfig
        Pydantic model of the filter to be created. Its `name` attribute selects
        which filter class is instantiated.

    Returns
    -------
    PatchFilter
        Instance of the requested patch filter.

    Raises
    ------
    ValueError
        If `filter_model.name` is not a supported patch filter.
    """
    if filter_model.name == SupportedPatchFilters.MAX:
        # Each `assert isinstance` narrows the generic config to its concrete model.
        assert isinstance(filter_model, MaxFilterConfig)
        return MaxPatchFilter(
            threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
        )
    elif filter_model.name == SupportedPatchFilters.MEANSTD:
        assert isinstance(filter_model, MeanSTDFilterConfig)
        return MeanStdPatchFilter(
            mean_threshold=filter_model.mean_threshold,
            std_threshold=filter_model.std_threshold,
            p=filter_model.p,
            seed=filter_model.seed,
        )
    elif filter_model.name == SupportedPatchFilters.SHANNON:
        assert isinstance(filter_model, ShannonFilterConfig)
        return ShannonPatchFilter(
            threshold=filter_model.threshold, p=filter_model.p, seed=filter_model.seed
        )

    # Report the offending name rather than dumping the whole config model.
    raise ValueError(f"Unknown filter name: {filter_model.name}")
@@ -0,0 +1,96 @@
1
+ """Filter using an image mask."""
2
+
3
+ import numpy as np
4
+
5
+ from careamics.dataset_ng.image_stack import GenericImageStack
6
+ from careamics.dataset_ng.patch_extractor import PatchExtractor
7
+ from careamics.dataset_ng.patch_filter.coordinate_filter_protocol import (
8
+ CoordinateFilterProtocol,
9
+ )
10
+ from careamics.dataset_ng.patching_strategies import PatchSpecs
11
+
12
+
13
# TODO is it more intuitive to have a negative mask? (mask of what to avoid)
class MaskCoordFilter(CoordinateFilterProtocol):
    """
    Filter patch coordinates based on an image mask.

    Attributes
    ----------
    mask_extractor : PatchExtractor[GenericImageStack]
        Patch extractor for the binary mask to use for filtering.
    coverage : float
        Minimum fraction (between 0 and 1) of masked pixels required to keep a
        patch.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mask_extractor: PatchExtractor[GenericImageStack],
        coverage: float,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaskCoordFilter.

        This filter removes patches that fall below a threshold fraction of
        masked pixels. The mask is expected to be a positive mask where masked
        pixels correspond to regions of interest.

        Parameters
        ----------
        mask_extractor : PatchExtractor[GenericImageStack]
            The patch extractor for the mask used for filtering.
        coverage : float
            Minimum fraction of masked pixels required to keep a patch. Must be
            between 0 and 1.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If coverage is not between 0 and 1.
        ValueError
            If p is not between 0 and 1.
        """
        if not (0 <= coverage <= 1):
            raise ValueError("Coverage must be between 0 and 1.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mask_extractor = mask_extractor
        self.coverage = coverage

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch_specs: PatchSpecs) -> bool:
        """
        Determine whether to filter out a patch based on an image mask.

        Parameters
        ----------
        patch_specs : PatchSpecs
            Specification of the patch to evaluate. Its items are forwarded as
            keyword arguments to the mask extractor's `extract_patch`.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # A random draw is consumed on every call (even when p == 1) so that
        # seeded runs stay reproducible regardless of the filtering outcome.
        if self.rng.uniform(0, 1) < self.p:
            mask_patch = self.mask_extractor.extract_patch(**patch_specs)

            # Fraction of the patch covered by the mask (assumes 0/1 mask values).
            masked_fraction = np.sum(mask_patch) / mask_patch.size
            if masked_fraction < self.coverage:
                return True
        return False
@@ -0,0 +1,188 @@
1
+ """Filter patch using a maximum filter."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import numpy as np
6
+ from scipy.ndimage import maximum_filter
7
+ from tqdm import tqdm
8
+
9
+ from careamics.dataset_ng.image_stack_loader import load_arrays
10
+ from careamics.dataset_ng.patch_extractor import PatchExtractor
11
+ from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
12
+ from careamics.dataset_ng.patching_strategies import TilingStrategy
13
+ from careamics.utils import get_logger
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
class MaxPatchFilter(PatchFilterProtocol):
    """
    A patch filter based on thresholding the maximum filter of the patch.

    Inspired by the CSBDeep approach.

    Attributes
    ----------
    threshold : float
        Threshold for the maximum filter of the patch.
    threshold_ratio : float
        Fraction of below-threshold pixels above which a patch is filtered out.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        threshold: float,
        p: float = 1.0,
        threshold_ratio: float = 0.25,
        seed: int | None = None,
    ) -> None:
        """
        Create a MaxPatchFilter.

        This filter removes patches whose maximum-filtered pixel values are
        mostly below a specified threshold.

        Parameters
        ----------
        threshold : float
            Threshold for the maximum filter of the patch.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        threshold_ratio : float, default=0.25
            If the fraction of maximum-filtered pixels below `threshold` exceeds
            this ratio, the patch is filtered out. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If p is not between 0 and 1.
        ValueError
            If threshold_ratio is not between 0 and 1.
        """
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")
        if not (0 <= threshold_ratio <= 1):
            raise ValueError("Threshold ratio must be between 0 and 1.")

        self.threshold = threshold
        self.threshold_ratio = threshold_ratio
        self.p = p
        self.rng = np.random.default_rng(seed)

    @staticmethod
    def _footprint_size(shape: Sequence[int]) -> tuple[int, ...]:
        """Return a max-filter window of half of each dimension, at least 1."""
        # Guarding against 0 avoids an invalid zero-size filter window for
        # dimensions of size 1 (or singleton channel axes).
        return tuple(max(dim // 2, 1) for dim in shape)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on its maximum filter.

        The patch is filtered out if its global maximum is below `threshold`,
        or if the fraction of pixels whose maximum-filtered value is below
        `threshold` exceeds `threshold_ratio`. The test is only applied with
        probability `p`; otherwise the patch is kept.

        Parameters
        ----------
        patch : numpy.ndarray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        if self.rng.uniform(0, 1) < self.p:
            # Cheap early exit: no pixel in the patch reaches the threshold.
            if np.max(patch) < self.threshold:
                return True

            filtered = maximum_filter(
                patch, self._footprint_size(patch.shape), mode="constant"
            )
            return bool(np.mean(filtered < self.threshold) > self.threshold_ratio)

        return False

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute the maximum map of an image.

        The map is computed over non-overlapping patches. This method can be used
        to assess a useful threshold for the MaxPatchFilter filter.

        Parameters
        ----------
        image : numpy.NDArray
            The image for which to compute the map, must be 2D or 3D.
        patch_size : Sequence[int]
            The size of the patches to compute the map over. Must be a sequence
            of two (YX) or three (ZYX) integers matching the image dimensions.

        Returns
        -------
        numpy.NDArray
            The max map of the image.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.

        Examples
        --------
        The `filter_map` method can be used to assess a useful threshold for the
        max filter. Below is an example of how to compute the max map of a
        random image and visualize thresholded versions of the map.
        >>> import numpy as np
        >>> from matplotlib import pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import MaxPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> max_filtered = MaxPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(1, 5, figsize=(20, 5)) # doctest: +SKIP
        >>> for i, thresh in enumerate([50 + i*5 for i in range(5)]):
        ...     ax[i].imshow(max_filtered >= thresh, cmap="gray") # doctest: +SKIP
        ...     ax[i].set_title(f"Threshold: {thresh}") # doctest: +SKIP
        >>> plt.show() # doctest: +SKIP
        """
        if len(image.shape) < 2 or len(image.shape) > 3:
            raise ValueError("Image must be 2D or 3D.")

        axes = "YX" if len(patch_size) == 2 else "ZYX"

        max_filtered = np.zeros_like(image, dtype=float)

        image_stacks = load_arrays(source=[image], axes=axes)
        extractor = PatchExtractor(image_stacks)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            patch_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )
        # Same window rule as `filter_out`, so the map reflects what is filtered.
        max_patch_size = MaxPatchFilter._footprint_size(patch_size)

        for idx in tqdm(range(tiling.n_patches), desc="Computing max map"):
            patch_spec = tiling.get_patch_spec(idx)
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=patch_spec["coords"],
                patch_size=patch_size,
            )

            coordinates = tuple(
                slice(patch_spec["coords"][i], patch_spec["coords"][i] + size)
                for i, size in enumerate(patch_size)
            )
            # NOTE(review): `squeeze` assumes only the channel axis is singleton;
            # a spatial dimension of size 1 would also be dropped.
            max_filtered[coordinates] = maximum_filter(
                patch.squeeze(), max_patch_size, mode="constant"
            )

        return max_filtered

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """
        Apply the max filter to a filter map.

        The filter map is the output of the `filter_map` method.

        Parameters
        ----------
        filter_map : numpy.NDArray
            The max filter map of the image.
        threshold : float
            The threshold to apply to the filter map.

        Returns
        -------
        numpy.NDArray
            A boolean array where True indicates that the patch should be kept
            (not filtered out) and False indicates that the patch should be filtered
            out.
        """
        threshold_map = filter_map >= threshold
        coverage = np.sum(threshold_map) * 100 / threshold_map.size
        logger.info("Image coverage: %.2f%%", coverage)
        return threshold_map
@@ -0,0 +1,218 @@
1
+ """Filter using mean and standard deviation thresholds."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ from careamics.dataset_ng.image_stack_loader import load_arrays
9
+ from careamics.dataset_ng.patch_extractor import PatchExtractor
10
+ from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
11
+ from careamics.dataset_ng.patching_strategies import TilingStrategy
12
+
13
+
14
class MeanStdPatchFilter(PatchFilterProtocol):
    """
    Filter patches based on mean and standard deviation thresholds.

    A patch is filtered out if its mean is below `mean_threshold` or, when
    `std_threshold` is set, if its standard deviation is below
    `std_threshold`. Each patch is only evaluated with probability `p`;
    patches that are not evaluated are always kept.

    Attributes
    ----------
    mean_threshold : float
        Threshold for the mean of the patch.
    std_threshold : float | None
        Threshold for the standard deviation of the patch, or None if no
        standard deviation filtering is applied.
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self,
        mean_threshold: float,
        std_threshold: float | None = None,
        p: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Create a MeanStdPatchFilter.

        This filter removes patches whose mean is below `mean_threshold`, or
        whose standard deviation is below `std_threshold` (if provided). The
        filtering is applied with a probability `p`, allowing for stochastic
        filtering.

        Parameters
        ----------
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If `mean_threshold` is negative.
        ValueError
            If `std_threshold` is not None and is negative.
        ValueError
            If `p` is not between 0 and 1.
        """
        if mean_threshold < 0:
            raise ValueError("Mean threshold must be non-negative.")
        if std_threshold is not None and std_threshold < 0:
            raise ValueError("Std threshold must be non-negative.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.mean_threshold = mean_threshold
        self.std_threshold = std_threshold

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on mean and std thresholds.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # Exactly one random draw per call, whether or not the patch ends up
        # being evaluated, so the random stream does not depend on patch content.
        if self.rng.uniform(0, 1) >= self.p:
            return False

        if np.mean(patch) < self.mean_threshold:
            return True

        # The standard deviation is only computed when it is actually needed;
        # `bool` converts the numpy boolean to the annotated return type.
        return self.std_threshold is not None and bool(
            np.std(patch) < self.std_threshold
        )

    @staticmethod
    def filter_map(image: np.ndarray, patch_size: Sequence[int]) -> np.ndarray:
        """
        Compute the mean and std map of an image.

        The mean and std are computed over non-overlapping patches. This method can be
        used to assess a useful threshold for the MeanStd filter.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate.
        patch_size : Sequence[int]
            The size of the patches to consider. Must have as many dimensions as
            `image`.

        Returns
        -------
        np.ndarray
            Stacked mean and std maps of the image, of shape `(2, *image.shape)`.
            Index 0 holds, for each pixel, the mean of the patch covering it, and
            index 1 the standard deviation of that patch.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If `patch_size` does not have as many dimensions as the image.

        Examples
        --------
        The `filter_map` method can be used to assess useful thresholds for the
        MeanStd filter.

        >>> import numpy as np
        >>> import matplotlib.pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import MeanStdPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] = rng.normal(50, 3, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> meanstd_map = MeanStdPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(3, 3, figsize=(10, 10))  # doctest: +SKIP
        >>> for i, mean_thresh in enumerate((48, 49, 50)):  # doctest: +SKIP
        ...     for j, std_thresh in enumerate((5, 6, 7)):
        ...         kept = MeanStdPatchFilter.apply_filter(
        ...             meanstd_map, mean_thresh, std_thresh
        ...         )
        ...         ax[i, j].imshow(kept, cmap="gray", vmin=0, vmax=1)
        ...         ax[i, j].set_title(f"Mean: {mean_thresh}, Std: {std_thresh}")
        >>> plt.show()  # doctest: +SKIP
        """
        if image.ndim not in (2, 3):
            raise ValueError("Image must be 2D or 3D.")
        # The axes string and the tiling are derived from `patch_size`, so a
        # mismatch with the image dimensionality cannot produce a valid map.
        if len(patch_size) != image.ndim:
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have the same number of "
                f"dimensions as the image ({image.ndim}D)."
            )

        axes = "YX" if image.ndim == 2 else "ZYX"

        mean = np.zeros_like(image, dtype=float)
        std = np.zeros_like(image, dtype=float)

        image_stacks = load_arrays(source=[image], axes=axes)
        extractor = PatchExtractor(image_stacks)
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            patch_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )

        for idx in tqdm(range(tiling.n_patches), desc="Computing Mean/STD map"):
            patch_spec = tiling.get_patch_spec(idx)
            start = patch_spec["coords"]
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=start,
                patch_size=patch_size,
            )

            # Region of the output maps covered by this patch.
            region = tuple(
                slice(start[dim], start[dim] + size)
                for dim, size in enumerate(patch_size)
            )
            mean[region] = np.mean(patch)
            std[region] = np.std(patch)

        return np.stack([mean, std], axis=0)

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        mean_threshold: float,
        std_threshold: float | None = None,
    ) -> np.ndarray:
        """
        Apply mean and std thresholds to a filter map.

        The filter map is the output of the `filter_map` method. A pixel passes
        when its mean is greater than or equal to `mean_threshold` and, if
        `std_threshold` is given, its std is greater than or equal to
        `std_threshold`. This mirrors `filter_out`, which only removes patches
        strictly below the thresholds.

        Parameters
        ----------
        filter_map : np.ndarray
            Stacked mean and std maps of the image.
        mean_threshold : float
            Threshold for the mean of the patch.
        std_threshold : float | None, default=None
            Threshold for the standard deviation of the patch. If None, then no
            standard deviation filtering is applied.

        Returns
        -------
        np.ndarray
            A binary map where True indicates patches that pass the filter.
        """
        # `>=` (not `>`) so that values exactly at a threshold are kept, as they
        # are by `filter_out`.
        keep = filter_map[0, ...] >= mean_threshold
        if std_threshold is not None:
            keep &= filter_map[1, ...] >= std_threshold
        return keep
@@ -0,0 +1,50 @@
1
+ """A protocol for patch filtering."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Protocol
5
+
6
+ import numpy as np
7
+
8
+
9
class PatchFilterProtocol(Protocol):
    """
    An interface for implementing patch filtering strategies.

    Implementations decide, patch by patch, whether a patch should be excluded
    (`filter_out`). They also provide a static `filter_map` method that
    evaluates the filtering criterion over a whole image, for instance to help
    choose suitable thresholds.
    """

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a given patch.

        Parameters
        ----------
        patch : numpy.NDArray
            The image patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out (excluded), False otherwise.
        """
        ...

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute a filter map for the entire image based on the patch filtering criteria.

        This is declared as a static method, so implementations cannot rely on
        instance state (e.g. configured thresholds) when computing the map.

        Parameters
        ----------
        image : numpy.NDArray
            The full image to evaluate.
        patch_size : Sequence[int]
            The size of the patches to consider.

        Returns
        -------
        numpy.NDArray
            A map of the filtering criterion evaluated over the image. It can be
            thresholded to preview which regions the filter would keep; its exact
            shape and content are implementation-defined.
        """
        ...