careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,212 @@
1
+ # The CAREamics Dataset
2
+
3
+ Welcome to the CAREamics dataset!
4
+
5
+ A PyTorch-based dataset, designed to be used with microscopy data. It is universal for the training, validation and prediction stages of a machine learning pipeline.
6
+
7
+ The key ethos is to create a modular and maintainable dataset comprised of swappable components that interact through interfaces. This should facilitate a smooth development process when extending the dataset's function to new features, and also enable advanced users to easily customize the dataset to their needs, by writing custom components. This is achieved by following a few key software engineering principles, detailed at the end of this README file.
8
+
9
+
10
+ ## Dataset Component overview
11
+
12
+ ```mermaid
13
+ ---
14
+ title: CAREamicsDataset
15
+ ---
16
+ classDiagram
17
+ class CAREamicsDataset{
18
+ +PatchExtractor input_extractor
19
+ +Optional[PatchExtractor] target_extractor
20
+ +PatchingStrategy patching_strategy
21
+ +list~Transform~ transforms
22
+ +\_\_getitem\_\_(int index) NDArray
23
+ }
24
+ class PatchingStrategy{
25
+ <<interface>>
26
+ +n_patches int
27
+ +get_patch_spec(index: int) PatchSpecs
28
+ }
29
+ class RandomPatchingStrategy{
30
+ }
31
+ class FixedRandomPatchingStrategy{
32
+ }
33
+ class SequentialPatchingStrategy{
34
+ }
35
+ class TilingStrategy{
36
+ +get_patch_spec(index: int) TileSpecs
37
+ }
38
+
39
+ class PatchExtractor{
40
+ +list~ImageStack~ image_stacks
41
+ +extract_patch(PatchSpecs) NDArray
42
+ }
43
+ class PatchSpecs {
44
+ <<TypedDict>>
45
+ +int data_idx
46
+ +int sample_idx
47
+ +Sequence~int~ coords
48
+ +Sequence~int~ patch_size
49
+ }
50
+ class TileSpecs {
51
+ <<TypedDict>>
52
+ +Sequence~int~ crop_coords
53
+ +Sequence~int~ crop_size
54
+ +Sequence~int~ stitch_coords
55
+ }
56
+
57
+ class ImageStack{
58
+ <<interface>>
59
+ +Union[Path, Literal["array"]] source
60
+ +Sequence~int~ data_shape
61
+ +DTypeLike data_type
62
+ +extract_patch(sample_idx, coords, patch_size) NDArray
63
+ }
64
+ class InMemoryImageStack {
65
+ }
66
+ class ZarrImageStack {
67
+ +Path source
68
+ }
69
+
70
+ CAREamicsDataset --* PatchExtractor: Is composed of
71
+ CAREamicsDataset --* PatchingStrategy: Is composed of
72
+ PatchExtractor --o ImageStack: Aggregates
73
+ ImageStack <|-- InMemoryImageStack: Implements
74
+ ImageStack <|-- ZarrImageStack: Implements
75
+ PatchingStrategy <|-- RandomPatchingStrategy: Implements
76
+ PatchingStrategy <|-- FixedRandomPatchingStrategy: Implements
77
+ PatchingStrategy <|-- SequentialPatchingStrategy: Implements
78
+ PatchingStrategy <|-- TilingStrategy: Implements
79
+ PatchSpecs <|-- TileSpecs: Inherits from
80
+ ```
81
+
82
+ ### `ImageStack` and implementations
83
+
84
+ This interface represents a set of image data, which can be saved with any subset of the
85
+ axes STCZYX, in any order (see below for a description of the dimensions). The `ImageStack`
86
+ interface's job is to act as an adapter for different data storage types, so that higher
87
+ level classes can access the image data without having to know the implementation details of
88
+ how to load or read data from each storage type. This means we can decide to support new storage
89
+ types by implementing a new concrete `ImageStack` class without having to change anything
90
+ in the `CAREamicsDataset` class. Advanced users can also choose to create their own
91
+ `ImageStack` class if they want to work with their own data storage type.
92
+
93
+ The interface provides an `extract_patch` method which will produce a patch from the image,
94
+ as a NumPy array, with the dimensions C(Z)YX. This method should be thought of as simply
95
+ a wrapper around the equivalent of NumPy slicing for each of the storage types.
96
+
97
+ #### Concrete implementations
98
+
99
+ - `InMemoryImageStack`: The underlying data is stored as a NumPy array in memory. It has some
100
+ additional constructor methods to load the data from known file formats such as TIFF files.
101
+ - `ZarrImageStack`: The underlying data is stored as a ZARR file on disk.
102
+
103
+ #### Axes description
104
+
105
+ - S is a generic sample dimension,
106
+ - T is a time dimension,
107
+ - C is a channel dimension,
108
+ - Z is a spatial dimension,
109
+ - Y is a spatial dimension,
110
+ - X is a spatial dimension.
111
+
112
+ ### `PatchExtractor`
113
+
114
+ The `PatchExtractor` class aggregates many `ImageStack` instances; this allows for multiple
115
+ images with different dimensions, and possibly different storage types to be treated as a single entity.
116
+ The class has an `extract_patch` method to extract a patch from any one of its `ImageStack`
117
+ objects. It can also possibly be extended when extra logic to extract patches is needed,
118
+ for example when constructing lateral-context inputs for the MicroSplit LVAE models.
119
+
120
+ ### `PatchingStrategy`
121
+
122
+ The `PatchingStrategy` class is an interface to generate patch specifications, where each of the
123
+ concrete implementations produces a set of patch specifications using a different strategy.
124
+
125
+ It has a `n_patches` attribute that can be accessed to find out how many patches the
126
+ strategy will produce, given the shapes of the image stacks it has been initialized with.
127
+ This is needed by the `CAREamicsDataset` to return its length.
128
+
129
+ Most importantly it has a `get_patch_spec` method, that takes an index and returns a
130
+ patch specification. For deterministic patching strategies, this method will always
131
+ return the same patch specification given the same index, but there are also random strategies
132
+ where the returned patch specification will change every time. The given index can never
133
+ be greater than or equal to `n_patches`.
134
+
135
+ #### Concrete implementations
136
+
137
+ - `RandomPatchingStrategy`: this strategy will produce random patches that will change
138
+ even if the `get_patch_spec` method is called with the same index.
139
+ - `FixedRandomPatchingStrategy`: this strategy will produce random patches, but the patch
140
+ will be the same if the `get_patch_spec` method is called with the same index. This is
141
+ useful for making sure validation is comparable epoch to epoch.
142
+ - `SequentialPatchingStrategy`: this strategy is deterministic and the patches will be
143
+ sequential with some specified overlap.
144
+ - `TilingStrategy`: this strategy is deterministic and the patches will be
145
+ sequential with some specified overlap. Rather than a `PatchSpecs` dictionary it will
146
+ produce a `TileSpecs` dictionary which includes some extra fields that are used for
147
+ stitching the tiles back together.
148
+
149
+ #### PatchSpecs
150
+
151
+ The `get_patch_spec` method returns a dictionary containing the keys `data_idx`, `sample_idx`, `coords` and `patch_size`.
152
+ These are the exact arguments that the `PatchExtractor.extract_patch` method takes. The patch specification
153
+ produced by the patching strategy is received by the `PatchExtractor` to, in turn, produce an image patch.
154
+
155
+ For type hinting, `PatchSpecs` is defined as a `TypedDict`.
156
+
157
+ ## Key Principles
158
+
159
+ The aim of all these principles is to create a system of interacting classes that have
160
+ low coupling. This allows for one section to be changed or extended without breaking functionality
161
+ elsewhere in the codebase.
162
+
163
+ ### Composition over inheritance
164
+
165
+ The principle of composition over inheritance is: rather than using inheritance to
166
+ extend or change the behavior of a class, instead, a class can be composed of modules
167
+ that can be swapped to extend or change behavior.
168
+
169
+ The reason to use composition is that it promotes the easy reuse of the underlying
170
+ components, it can prevent a subclass explosion, and it leads to a maintainable and
171
+ easily extendable design. A software architecture based on composition is normally
172
+ maintainable and extendable because if a component needs to change then the whole class
173
+ shouldn't have to be refactored and if a new feature needs to be added, usually an additional
174
+ component can be added to the class.
175
+
176
+ The `CAREamicsDataset` is composed of `PatchExtractor`, `PatchingStrategy` and `Transform` components.
177
+ The `PatchingStrategy` classes implement an interface so the dataset can switch between
178
+ different strategies. The `PatchExtractor` is composed of many `ImageStack` instances;
179
+ new image stacks can be added to extend the type of data that the dataset can read from.
180
+
181
+ ### Dependency Inversion
182
+
183
+ The dependency inversion principle states:
184
+
185
+ 1. High-level modules should not depend on low-level modules. Both high-level and
186
+ low-level modules should depend on abstractions (e.g. interfaces).
187
+ 2. Abstractions should not depend on details (concrete implementations). Details should
188
+ depend on abstractions.
189
+
190
+ In other words high level modules that provide complex logic should be easily reusable
191
+ and not depend on implementation details of low-level modules that provide utility functionality.
192
+ This can be achieved by introducing abstractions that decouple high and low level modules.
193
+
194
+ An example of the dependency inversion principle in use is how the `PatchExtractor` only
195
+ depends on the `ImageStack` interface, and does not have to have any knowledge of the
196
+ concrete implementations. The concrete `ImageStack` implementations also do not have
197
+ any knowledge of the `PatchExtractor` or any other higher-level functionality that the
198
+ dataset needs.
199
+
200
+ ### Single Responsibility Principle
201
+
202
+ Each component should have a small scope of responsibility that is easily defined. This
203
+ should make the code easier to maintain and hopefully reduce the number of places in the
204
+ code that have to change when introducing a new feature.
205
+
206
+ - `ImageStack` responsibility: to act as an adapter for loading and reading image data
207
+ from different underlying storage.
208
+ - `PatchExtractor` responsibility: to extract patches from a set of image stacks.
209
+ - `PatchingStrategy` responsibility: to produce patch specifications given an index, through
210
+ an interface that hides the underlying implementation.
211
+ - `CAREamicsDataset` responsibility: to orchestrate the interactions of its underlying
212
+ components to produce an input patch (and target patch when required) given an index.
File without changes
@@ -0,0 +1,365 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Any, Generic, Literal, NamedTuple, Union
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+ from torch.utils.data import Dataset
8
+ from tqdm.auto import tqdm
9
+
10
+ from careamics.config.data.ng_data_config import Mode, NGDataConfig, WholePatchingConfig
11
+ from careamics.config.transformations import NormalizeConfig
12
+ from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
13
+ from careamics.dataset.patching.patching import Stats
14
+ from careamics.transforms import Compose
15
+
16
+ from .image_stack import GenericImageStack, ZarrImageStack
17
+ from .patch_extractor import PatchExtractor
18
+ from .patch_filter import create_coord_filter, create_patch_filter
19
+ from .patching_strategies import (
20
+ PatchSpecs,
21
+ RegionSpecs,
22
+ create_patching_strategy,
23
+ )
24
+
25
+
26
class ImageRegionData(NamedTuple, Generic[RegionSpecs]):
    """
    Data structure for arrays produced by the dataset and propagated through models.

    An ImageRegionData may be a patch during training/validation, a tile during
    prediction with tiling, or a whole image during prediction without tiling.

    `data_shape` may not correspond to the shape of the original data if a subset
    of the channels has been requested, in which case the channel dimension may
    be smaller than that of the original data and only correspond to the requested
    number of channels.

    ImageRegionData may be collated in batches during training by the DataLoader. In
    that case:
        - data: arrays are collated into NDArray of shape (B,C,Z,Y,X)
        - source: list of str, length B
        - data_shape: list of tuples of int, each tuple being of length B and
            representing the shape of the original images in the corresponding
            dimension
        - dtype: list of str, length B
        - axes: list of str, length B
        - region_spec: dict of {str: sequence}, each sequence being of length B
        - additional_metadata: list of dict

    Description of the fields is given for the uncollated case (non-batched).
    """

    data: NDArray
    """Patch, tile or image in C(Z)YX format."""

    # `Union[str, Literal["array"]]` is equivalent to `str` for a type checker; the
    # literal is kept to document the value used for in-memory arrays.
    source: Union[str, Literal["array"]]
    """Source of the data, e.g. file path, zarr URI, or "array" for in-memory arrays."""

    # NOTE(review): `_adjust_shape_for_channels` in this module replaces index 1 of
    # this shape with the number of requested channels, i.e. it assumes an SC(Z)YX
    # layout in which the sample and channel axes are always present.
    data_shape: Sequence[int]
    """Shape of the original image in (SCZ)YX format and order. If channels are
    subsetted, the channel dimension corresponds to the number of requested channels."""

    dtype: str  # dtype should be str for collate
    """Data type of the original image as a string."""

    axes: str
    """Axes of the original data array, in format SCZYX."""

    region_spec: RegionSpecs  # PatchSpecs or subclasses, e.g. TileSpecs
    """Specifications of the region within the original image from where `data` is
    extracted. Of type PatchSpecs during training/validation and prediction without
    tiling, and TileSpecs during prediction with tiling.
    """

    additional_metadata: dict[str, Any]
    """Additional metadata to be stored with the image region. Currently used to store
    chunk and shard information for zarr image stacks."""
77
+
78
+
79
# Type alias for dataset inputs: a sequence of arrays or a sequence of paths.
InputType = Union[Sequence[NDArray[Any]], Sequence[Path]]
80
+
81
+
82
+ def _adjust_shape_for_channels(
83
+ shape: Sequence[int],
84
+ channels: Sequence[int] | None,
85
+ value: int | Literal["channels"] = "channels",
86
+ ) -> tuple[int, ...]:
87
+ """Adjust shape to account for channel subsetting.
88
+
89
+ Parameters
90
+ ----------
91
+ shape : Sequence[int]
92
+ The original data shape in SC(Z)YX format.
93
+ channels : Sequence[int] | None
94
+ The list of channels to select. If None, no adjustment is made.
95
+ value : int | Literal["channels"], default="channels"
96
+ The value to replace the channel dimension with. If "channels", the length
97
+ of the channels list is used, by default "channels".
98
+
99
+ Returns
100
+ -------
101
+ tuple[int, ...]
102
+ The adjusted data shape in SC(Z)YX format.
103
+ """
104
+ if channels is not None:
105
+ adjusted_shape = list(shape)
106
+ adjusted_shape[1] = len(channels) if value == "channels" else value
107
+ return tuple(adjusted_shape)
108
+ return tuple(shape)
109
+
110
+
111
+ def _patch_size_within_data_shapes(
112
+ data_shapes: Sequence[Sequence[int]], patch_size: Sequence[int]
113
+ ) -> bool:
114
+ """Determine whether all the data_shapes are greater than the patch size.
115
+
116
+ Parameters
117
+ ----------
118
+ data_shapes : Sequence[Sequence[int]]
119
+ A sequence of data shapes. They must be in the format SC(Z)YX.
120
+ patch_size : Sequence[int]
121
+ A patch size that must specify the size of the patch in all the spatial
122
+ dimensions, in the format (Z)YX.
123
+
124
+ Returns
125
+ -------
126
+ bool
127
+ If all the data shapes are greater than the patch size.
128
+ """
129
+ smaller_than_shapes = [
130
+ # skip sample and channel dimension in data_shape
131
+ (np.array(patch_size) < np.array(data_shape[2:])).all()
132
+ for data_shape in data_shapes
133
+ ]
134
+ return all(smaller_than_shapes)
135
+
136
+
137
class CareamicsDataset(Dataset, Generic[GenericImageStack]):
    """Dataset returning image regions together with their metadata.

    For a given index, the dataset asks its patching strategy for a patch
    specification, extracts the corresponding region from the input (and, if
    present, target) patch extractor, applies the transforms and wraps the
    resulting arrays in `ImageRegionData`.

    Parameters
    ----------
    data_config : NGDataConfig
        Data configuration. The mode, patching, channels, axes, filters,
        transforms and normalization statistics are read from it.
    input_extractor : PatchExtractor
        Patch extractor for the input images.
    target_extractor : PatchExtractor or None, default=None
        Patch extractor for the target images, if any.
    mask_extractor : PatchExtractor or None, default=None
        Patch extractor for the masks. It is only used to create the coordinate
        filter, and only if `data_config.coord_filter` is not None.

    Raises
    ------
    ValueError
        If the mode is not predicting, the patching configuration is not a
        `WholePatchingConfig`, and `_patch_size_within_data_shapes` reports that
        the patch size does not fit within all the input images.
    """

    def __init__(
        self,
        data_config: NGDataConfig,
        input_extractor: PatchExtractor[GenericImageStack],
        target_extractor: PatchExtractor[GenericImageStack] | None = None,
        mask_extractor: PatchExtractor[GenericImageStack] | None = None,
    ) -> None:
        """Validate the patch size and set up filters, patching and statistics.

        See the class docstring for the description of the parameters.
        """
        # Make sure all the image sizes are greater than the patch size for training
        data_shapes = [
            image_stack.data_shape for image_stack in input_extractor.image_stacks
        ]
        if data_config.mode != Mode.PREDICTING:
            if not isinstance(
                data_config.patching, WholePatchingConfig
            ) and not _patch_size_within_data_shapes(
                data_shapes, data_config.patching.patch_size
            ):
                raise ValueError(
                    "Not all images sizes are greater than the patch size for training "
                    "and validation."
                )

        self.config = data_config

        self.input_extractor = input_extractor
        self.target_extractor = target_extractor

        self.patch_filter = (
            create_patch_filter(self.config.patch_filter)
            if self.config.patch_filter is not None
            else None
        )
        # The coordinate filter needs both a configuration and a mask extractor;
        # if either is missing, no coordinate filtering is performed.
        self.coord_filter = (
            create_coord_filter(self.config.coord_filter, mask=mask_extractor)
            if self.config.coord_filter is not None and mask_extractor is not None
            else None
        )
        self.patch_filter_patience = self.config.patch_filter_patience

        self.patching_strategy = create_patching_strategy(
            data_shapes=self.input_extractor.shapes,
            patching_config=self.config.patching,
        )

        # Statistics must be available before the transforms are created, since
        # the normalization transform is built from them.
        self.input_stats, self.target_stats = self._initialize_statistics()

        self.transforms = self._initialize_transforms()

    def _initialize_transforms(self) -> Compose | None:
        """Create the transform pipeline.

        Normalization, configured from the input and target statistics, always
        comes first. In training mode, the transforms listed in the configuration
        are appended to it.

        Returns
        -------
        Compose | None
            The composed transforms.
        """
        normalize = NormalizeConfig(
            image_means=self.input_stats.means,
            image_stds=self.input_stats.stds,
            target_means=self.target_stats.means,
            target_stds=self.target_stats.stds,
        )
        if self.config.mode == Mode.TRAINING:
            # TODO: initialize normalization separately depending on configuration
            return Compose(transform_list=[normalize] + list(self.config.transforms))

        # TODO: add TTA
        return Compose(transform_list=[normalize])

    def _calculate_stats(
        self, data_extractor: PatchExtractor[GenericImageStack]
    ) -> Stats:
        """Compute statistics over the patches produced by the patching strategy.

        Every patch specification of the patching strategy is queried once and the
        corresponding patch is fed to a `WelfordStatistics` accumulator.

        Parameters
        ----------
        data_extractor : PatchExtractor
            The patch extractor from which the patches are read.

        Returns
        -------
        Stats
            The means and standard deviations returned by the accumulator.
        """
        image_stats = WelfordStatistics()
        n_patches = self.patching_strategy.n_patches

        for idx in tqdm(range(n_patches), desc="Computing statistics"):
            patch_spec = self.patching_strategy.get_patch_spec(idx)
            patch = data_extractor.extract_channel_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                channels=self.config.channels,
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )
            # TODO: statistics accept SCYX format, while patch is CYX
            # A leading singleton dimension is added to the patch for that reason.
            image_stats.update(patch[None, ...], sample_idx=idx)

        image_means, image_stds = image_stats.finalize()
        return Stats(image_means, image_stds)

    # TODO: add running stats
    def _initialize_statistics(self) -> tuple[Stats, Stats]:
        """Get the input and target statistics, computing them if necessary.

        Statistics are taken from the configuration when both the means and the
        standard deviations are provided there. Otherwise, they are computed from
        the corresponding extractor. Without configured target statistics and
        without a target extractor, the target statistics are empty.

        Returns
        -------
        tuple[Stats, Stats]
            The input statistics and the target statistics.
        """
        if self.config.image_means is not None and self.config.image_stds is not None:
            input_stats = Stats(self.config.image_means, self.config.image_stds)
        else:
            input_stats = self._calculate_stats(self.input_extractor)

        target_stats = Stats((), ())

        if self.config.target_means is not None and self.config.target_stds is not None:
            target_stats = Stats(self.config.target_means, self.config.target_stds)
        elif self.target_extractor is not None:
            target_stats = self._calculate_stats(self.target_extractor)

        return input_stats, target_stats

    def __len__(self) -> int:
        """Return the number of patches given by the patching strategy.

        Returns
        -------
        int
            The number of patches.
        """
        return self.patching_strategy.n_patches

    def _create_image_region(
        self, patch: np.ndarray, patch_spec: PatchSpecs, extractor: PatchExtractor
    ) -> ImageRegionData:
        """Wrap a patch and the metadata of its source image in `ImageRegionData`.

        Parameters
        ----------
        patch : np.ndarray
            The array to store in the `data` field.
        patch_spec : PatchSpecs
            The specification from which the patch was extracted. Its "data_idx"
            entry selects the image stack the metadata is read from.
        extractor : PatchExtractor
            The patch extractor holding the image stacks.

        Returns
        -------
        ImageRegionData
            The patch together with its source, dtype, shape, axes, region
            specification and additional metadata.
        """
        data_idx = patch_spec["data_idx"]
        image_stack: GenericImageStack = extractor.image_stacks[data_idx]

        # adjust the number of channels in data_shape if needed
        data_shape = _adjust_shape_for_channels(
            shape=image_stack.data_shape,
            channels=self.config.channels,
        )

        # additional metadata for zarr image stacks
        additional_metadata: dict[str, Any] = {}
        if isinstance(image_stack, ZarrImageStack):
            additional_metadata["chunks"] = image_stack.chunks

            if image_stack.shards is not None:
                additional_metadata["shards"] = image_stack.shards

        return ImageRegionData(
            data=patch,
            source=str(image_stack.source),
            dtype=str(image_stack.data_dtype),
            data_shape=data_shape,
            # TODO: should it be axes of the original image instead?
            axes=self.config.axes,
            region_spec=patch_spec,
            additional_metadata=additional_metadata,
        )

    def _extract_patches(
        self, patch_spec: PatchSpecs
    ) -> tuple[NDArray, NDArray | None]:
        """Extract input and target patches based on patch specifications.

        Parameters
        ----------
        patch_spec : PatchSpecs
            The patch specification, used for both the input and the target.

        Returns
        -------
        tuple[NDArray, NDArray | None]
            The input patch and the target patch; the latter is None if there is
            no target extractor.
        """
        input_patch = self.input_extractor.extract_channel_patch(
            data_idx=patch_spec["data_idx"],
            sample_idx=patch_spec["sample_idx"],
            channels=self.config.channels,
            coords=patch_spec["coords"],
            patch_size=patch_spec["patch_size"],
        )

        target_patch = (
            self.target_extractor.extract_channel_patch(
                data_idx=patch_spec["data_idx"],
                sample_idx=patch_spec["sample_idx"],
                # TODO does not allow selecting different channels for target
                channels=self.config.channels,
                coords=patch_spec["coords"],
                patch_size=patch_spec["patch_size"],
            )
            if self.target_extractor is not None
            else None
        )
        return input_patch, target_patch

    def _get_filtered_patch(
        self, index: int
    ) -> tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs]:
        """Extract a patch that passes filtering criteria with retry logic.

        Filtering is only active in training mode and when at least one of the
        patch filter and the coordinate filter exists. A rejected patch is
        replaced by querying the patching strategy again with the same index,
        until a patch is accepted or the patience is exhausted. The last patch
        extracted is returned even if it was rejected by the patch filter.

        Parameters
        ----------
        index : int
            The index passed to the patching strategy.

        Returns
        -------
        tuple[NDArray[Any], NDArray[Any] | None, PatchSpecs]
            The input patch, the target patch (None without a target extractor)
            and the patch specification they were extracted from.
        """
        should_filter = self.config.mode == Mode.TRAINING and (
            self.patch_filter is not None or self.coord_filter is not None
        )

        # Always allow at least one attempt: with a non-positive patience the
        # loop body would never run and there would be no patch to return.
        patience = max(1, self.patch_filter_patience)

        rejected = True
        while rejected and patience > 0:
            # query patches
            patch_spec = self.patching_strategy.get_patch_spec(index)

            # filter patch based on coordinates if needed
            if should_filter and self.coord_filter is not None:
                if self.coord_filter.filter_out(patch_spec):
                    patience -= 1

                    # TODO should we raise an error rather than silently accept
                    # patches?
                    # While patience remains, try other coordinates. On the last
                    # attempt, fall through and accept these coordinates, so that
                    # the loop never exits before a patch has been extracted.
                    if patience > 0:
                        continue

            input_patch, target_patch = self._extract_patches(patch_spec)

            # filter patch based on values if needed
            if should_filter and self.patch_filter is not None:
                rejected = self.patch_filter.filter_out(input_patch)
                patience -= 1  # decrease patience
            else:
                rejected = False

        return input_patch, target_patch, patch_spec

    def __getitem__(
        self, index: int
    ) -> Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]]:
        """Return the transformed input (and target) region for the given index.

        Parameters
        ----------
        index : int
            The index of the patch.

        Returns
        -------
        tuple[ImageRegionData] or tuple[ImageRegionData, ImageRegionData]
            A 1-tuple with the input region if there is no target, otherwise the
            input region and the target region.
        """
        input_patch, target_patch, patch_spec = self._get_filtered_patch(index)

        # apply transforms
        if self.transforms is not None:
            if self.target_extractor is not None:
                input_patch, target_patch = self.transforms(input_patch, target_patch)
            else:
                # TODO: compose doesn't return None for target patch anymore
                # so have to do this annoying if else
                (input_patch,) = self.transforms(input_patch, target_patch)
                target_patch = None

        input_data = self._create_image_region(
            patch=input_patch, patch_spec=patch_spec, extractor=self.input_extractor
        )

        if target_patch is not None and self.target_extractor is not None:
            target_data = self._create_image_region(
                patch=target_patch,
                patch_spec=patch_spec,
                extractor=self.target_extractor,
            )
            return input_data, target_data
        else:
            return (input_data,)