careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/lightning/microsplit_data_module.py
@@ -0,0 +1,632 @@
+ """MicroSplit data module for training and validation."""
+
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import Union
+
+ import numpy as np
+ import pytorch_lightning as L
+ import tifffile
+ from numpy.typing import NDArray
+ from torch.utils.data import DataLoader
+
+ from careamics.dataset.dataset_utils.dataset_utils import reshape_array
+ from careamics.lvae_training.dataset import (
+     DataSplitType,
+     DataType,
+     LCMultiChDloader,
+     MicroSplitDataConfig,
+ )
+ from careamics.lvae_training.dataset.types import TilingMode
+
+
+ # TODO refactor
+ def load_one_file(fpath):
+     """Load a single 2D image file.
+
+     Parameters
+     ----------
+     fpath : str or Path
+         Path to the image file.
+
+     Returns
+     -------
+     numpy.ndarray
+         Reshaped image data.
+     """
+     data = tifffile.imread(fpath)
+     if len(data.shape) == 2:
+         axes = "YX"
+     elif len(data.shape) == 3:
+         axes = "SYX"
+     elif len(data.shape) == 4:
+         axes = "STYX"
+     else:
+         raise ValueError(f"Invalid data shape: {data.shape}")
+     data = reshape_array(data, axes)
+     data = data.reshape(-1, data.shape[-2], data.shape[-1])
+     return data
+
+
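The function above collapses any leading sample/time axes into a single stack of 2D frames. A quick sketch of the shapes involved (the file name is hypothetical, not part of the package):

    # a TIFF saved as (S, T, Y, X) = (3, 5, 64, 64)
    frames = load_one_file("stack_styx.tif")
    # the trailing reshape flattens all leading axes:
    # frames.shape == (15, 64, 64)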
+ # TODO refactor
+ def load_data(datadir):
+     """Load data from a directory containing channel subdirectories with image files.
+
+     Parameters
+     ----------
+     datadir : str or Path
+         Path to the data directory containing channel subdirectories.
+
+     Returns
+     -------
+     numpy.ndarray
+         Stacked array of all channels' data.
+     """
+     data_path = Path(datadir)
+
+     channel_dirs = sorted(p for p in data_path.iterdir() if p.is_dir())
+     channels_data = []
+
+     for channel_dir in channel_dirs:
+         image_files = sorted(f for f in channel_dir.iterdir() if f.is_file())
+         channel_images = [load_one_file(image_path) for image_path in image_files]
+
+         channel_stack = np.concatenate(
+             channel_images, axis=0
+         )  # FIXME: this line works if images have a singleton channel dimension.
+         # Specify in the notebook or change with `torch.stack`??
+         channels_data.append(channel_stack)
+
+     final_data = np.stack(channels_data, axis=-1)
+     return final_data
+
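In other words, load_data expects one subdirectory per channel, each holding that channel's image files, and stacks channels last. A hypothetical layout and the resulting shape:

    data/
    ├── channel_0/
    │   ├── img_000.tif
    │   └── img_001.tif
    └── channel_1/
        ├── img_000.tif
        └── img_001.tif

    stack = load_data("data")
    # each per-channel stack is (N, Y, X); np.stack(..., axis=-1)
    # yields (N, Y, X, C), here with C == 2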
+
+ # TODO refactor
+ def get_datasplit_tuples(val_fraction, test_fraction, data_length):
+     """Get train/val/test indices for data splitting.
+
+     Parameters
+     ----------
+     val_fraction : float or None
+         Fraction of data to use for validation.
+     test_fraction : float or None
+         Fraction of data to use for testing.
+     data_length : int
+         Total length of the dataset.
+
+     Returns
+     -------
+     tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
+         Training, validation, and test indices.
+     """
+     indices = np.arange(data_length)
+     np.random.shuffle(indices)
+
+     if val_fraction is None:
+         val_fraction = 0.0
+     if test_fraction is None:
+         test_fraction = 0.0
+
+     val_size = int(data_length * val_fraction)
+     test_size = int(data_length * test_fraction)
+     train_size = data_length - val_size - test_size
+
+     train_idx = indices[:train_size]
+     val_idx = indices[train_size : train_size + val_size]
+     test_idx = indices[train_size + val_size :]
+
+     return train_idx, val_idx, test_idx
+
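For example, with 100 samples and 10% validation / 10% test fractions, the shuffled indices are split 80/10/10:

    train_idx, val_idx, test_idx = get_datasplit_tuples(0.1, 0.1, 100)
    # val_size = 10, test_size = 10, train_size = 100 - 10 - 10 = 80
    assert len(train_idx) == 80 and len(val_idx) == 10 and len(test_idx) == 10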
+
+ # TODO refactor
+ def get_train_val_data(
+     data_config,
+     datadir,
+     datasplit_type: DataSplitType,
+     val_fraction=None,
+     test_fraction=None,
+     allow_generation=None,
+     **kwargs,
+ ):
+     """Load and split data according to configuration.
+
+     Parameters
+     ----------
+     data_config : MicroSplitDataConfig
+         Data configuration object.
+     datadir : str or Path
+         Path to the data directory.
+     datasplit_type : DataSplitType
+         Type of data split to return.
+     val_fraction : float, optional
+         Fraction of data to use for validation.
+     test_fraction : float, optional
+         Fraction of data to use for testing.
+     allow_generation : bool, optional
+         Whether to allow data generation.
+     **kwargs
+         Additional keyword arguments.
+
+     Returns
+     -------
+     numpy.ndarray
+         Split data array.
+     """
+     data = load_data(datadir)
+     train_idx, val_idx, test_idx = get_datasplit_tuples(
+         val_fraction, test_fraction, len(data)
+     )
+
+     if datasplit_type == DataSplitType.All:
+         data = data.astype(np.float64)
+     elif datasplit_type == DataSplitType.Train:
+         data = data[train_idx].astype(np.float64)
+     elif datasplit_type == DataSplitType.Val:
+         data = data[val_idx].astype(np.float64)
+     elif datasplit_type == DataSplitType.Test:
+         # TODO this is only used for prediction, and only because old dataset uses it
+         data = data[test_idx].astype(np.float64)
+     else:
+         raise ValueError(f"Invalid datasplit type: {datasplit_type}")
+
+     return data
+
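Note that get_datasplit_tuples shuffles with NumPy's global RNG, so each call to get_train_val_data draws a fresh permutation. A sketch of how disjoint Train/Val subsets could be obtained across separate calls, by reseeding before each one (the config and path are placeholders):

    import numpy as np

    np.random.seed(42)
    train = get_train_val_data(cfg, "./data", DataSplitType.Train, 0.1, 0.1)
    np.random.seed(42)  # same seed, same permutation, hence disjoint slices
    val = get_train_val_data(cfg, "./data", DataSplitType.Val, 0.1, 0.1)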
+
+ class MicroSplitDataModule(L.LightningDataModule):
+     """Lightning DataModule for MicroSplit-style datasets.
+
+     Matches the interface of TrainDataModule, but internally uses original MicroSplit
+     dataset logic.
+
+     Parameters
+     ----------
+     data_config : MicroSplitDataConfig
+         Configuration for the MicroSplit dataset.
+     train_data : str
+         Path to training data directory.
+     val_data : str, optional
+         Path to validation data directory.
+     train_data_target : str, optional
+         Path to training target data.
+     val_data_target : str, optional
+         Path to validation target data.
+     read_source_func : Callable, optional
+         Function to read source data.
+     extension_filter : str, optional
+         File extension filter.
+     val_percentage : float, optional
+         Percentage of data to use for validation, by default 0.1.
+     val_minimum_split : int, optional
+         Minimum number of samples for validation split, by default 5.
+     use_in_memory : bool, optional
+         Whether to use in-memory dataset, by default True.
+     """
+
+     def __init__(
+         self,
+         data_config: MicroSplitDataConfig,
+         train_data: str,
+         val_data: str | None = None,
+         train_data_target: str | None = None,
+         val_data_target: str | None = None,
+         read_source_func: Callable | None = None,
+         extension_filter: str = "",
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ):
+         """Initialize MicroSplitDataModule.
+
+         Parameters
+         ----------
+         data_config : MicroSplitDataConfig
+             Configuration for the MicroSplit dataset.
+         train_data : str
+             Path to training data directory.
+         val_data : str, optional
+             Path to validation data directory.
+         train_data_target : str, optional
+             Path to training target data.
+         val_data_target : str, optional
+             Path to validation target data.
+         read_source_func : Callable, optional
+             Function to read source data.
+         extension_filter : str, optional
+             File extension filter.
+         val_percentage : float, optional
+             Percentage of data to use for validation, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of samples for validation split, by default 5.
+         use_in_memory : bool, optional
+             Whether to use in-memory dataset, by default True.
+         """
+         super().__init__()
+         # Dataset selection logic (adapted from create_train_val_datasets)
+         self.train_config = data_config  # Should configs be separated?
+         self.val_config = data_config
+         self.test_config = data_config
+
+         datapath = train_data
+         load_data_func = read_source_func
+
+         dataset_class = LCMultiChDloader  # TODO hardcoded for now
+
+         # Create datasets
+         self.train_dataset = dataset_class(
+             self.train_config,
+             datapath,
+             load_data_fn=load_data_func,
+             val_fraction=val_percentage,
+             test_fraction=0.1,
+         )
+         max_val = self.train_dataset.get_max_val()
+         self.val_config.max_val = max_val
+         if self.train_config.datasplit_type == DataSplitType.All:
+             self.val_config.datasplit_type = DataSplitType.All
+             self.test_config.datasplit_type = DataSplitType.All
+         self.val_dataset = dataset_class(
+             self.val_config,
+             datapath,
+             load_data_fn=load_data_func,
+             val_fraction=val_percentage,
+             test_fraction=0.1,
+         )
+         self.test_config.max_val = max_val
+         self.test_dataset = dataset_class(
+             self.test_config,
+             datapath,
+             load_data_fn=load_data_func,
+             val_fraction=val_percentage,
+             test_fraction=0.1,
+         )
+         mean_val, std_val = self.train_dataset.compute_mean_std()
+         self.train_dataset.set_mean_std(mean_val, std_val)
+         self.val_dataset.set_mean_std(mean_val, std_val)
+         self.test_dataset.set_mean_std(mean_val, std_val)
+         data_stats = self.train_dataset.get_mean_std()
+
+         # Store data statistics
+         self.data_stats = (
+             data_stats[0],
+             data_stats[1],
+         )  # TODO repeats old logic, revisit
+
+     def train_dataloader(self):
+         """Create a dataloader for training.
+
+         Returns
+         -------
+         DataLoader
+             Training dataloader.
+         """
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.train_config.batch_size,
+             # TODO should be inside dataloader params?
+             **self.train_config.train_dataloader_params,
+         )
+
+     def val_dataloader(self):
+         """Create a dataloader for validation.
+
+         Returns
+         -------
+         DataLoader
+             Validation dataloader.
+         """
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.train_config.batch_size,
+             **self.val_config.val_dataloader_params,  # TODO duplicated
+         )
+
+     def get_data_stats(self):
+         """Get data statistics.
+
+         Returns
+         -------
+         tuple
+             The (mean, std) statistics of the training data and its maximum value.
+         """
+         return self.data_stats, self.val_config.max_val  # TODO should be in the config?
+
+
+ def create_microsplit_train_datamodule(
+     train_data: str,
+     patch_size: tuple,
+     data_type: DataType,
+     axes: str,  # TODO should be there after refactoring
+     batch_size: int,
+     val_data: str | None = None,
+     num_channels: int = 2,
+     depth3D: int = 1,
+     grid_size: tuple | None = None,
+     multiscale_count: int | None = None,
+     tiling_mode: TilingMode = TilingMode.ShiftBoundary,
+     read_source_func: Callable | None = None,  # TODO should be there after refactoring
+     extension_filter: str = "",
+     val_percentage: float = 0.1,
+     val_minimum_split: int = 5,
+     use_in_memory: bool = True,
+     transforms: list | None = None,  # TODO should it be here?
+     train_dataloader_params: dict | None = None,
+     val_dataloader_params: dict | None = None,
+     **dataset_kwargs,
+ ) -> MicroSplitDataModule:
+     """
+     Create a MicroSplitDataModule for MicroSplit-style datasets.
+
+     Parameters
+     ----------
+     train_data : str
+         Path to training data.
+     patch_size : tuple
+         Size of one patch of data.
+     data_type : DataType
+         Type of the dataset (must be a DataType enum value).
+     axes : str
+         Axes of the data (e.g., 'SYX').
+     batch_size : int
+         Batch size for dataloaders.
+     val_data : str, optional
+         Path to validation data.
+     num_channels : int, default=2
+         Number of channels in the input.
+     depth3D : int, default=1
+         Number of slices in 3D.
+     grid_size : tuple, optional
+         Grid size for patch extraction.
+     multiscale_count : int, optional
+         Number of LC scales.
+     tiling_mode : TilingMode, default=ShiftBoundary
+         Tiling mode for patch extraction.
+     read_source_func : Callable, optional
+         Function to read the source data.
+     extension_filter : str, optional
+         File extension filter.
+     val_percentage : float, default=0.1
+         Percentage of training data to use for validation.
+     val_minimum_split : int, default=5
+         Minimum number of patches/files for validation split.
+     use_in_memory : bool, default=True
+         Use in-memory dataset if possible.
+     transforms : list, optional
+         List of transforms to apply.
+     train_dataloader_params : dict, optional
+         Parameters for training dataloader.
+     val_dataloader_params : dict, optional
+         Parameters for validation dataloader.
+     **dataset_kwargs :
+         Additional arguments passed to DatasetConfig.
+
+     Returns
+     -------
+     MicroSplitDataModule
+         Configured MicroSplitDataModule instance.
+     """
+     # Create dataset configs with only valid parameters
+     dataset_config_params = {
+         "data_type": data_type,
+         "image_size": patch_size,
+         "num_channels": num_channels,
+         "depth3D": depth3D,
+         "grid_size": grid_size,
+         "multiscale_lowres_count": multiscale_count,
+         "tiling_mode": tiling_mode,
+         "batch_size": batch_size,
+         "train_dataloader_params": train_dataloader_params,
+         "val_dataloader_params": val_dataloader_params,
+         **dataset_kwargs,
+     }
+
+     train_config = MicroSplitDataConfig(
+         **dataset_config_params,
+         datasplit_type=DataSplitType.Train,
+     )
+     # val_config = MicroSplitDataConfig(
+     #     **dataset_config_params,
+     #     datasplit_type=DataSplitType.Val,
+     # )
+     # TODO, data config is duplicated here and in configuration
+
+     return MicroSplitDataModule(
+         data_config=train_config,
+         train_data=train_data,
+         val_data=val_data or train_data,
+         train_data_target=None,
+         val_data_target=None,
+         read_source_func=get_train_val_data,  # Use our wrapped function
+         extension_filter=extension_filter,
+         val_percentage=val_percentage,
+         val_minimum_split=val_minimum_split,
+         use_in_memory=use_in_memory,
+     )
+
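A minimal training sketch built on this factory; the paths, patch size, and the DataType member are placeholders that would need to match an actual dataset and a compatible LightningModule:

    import pytorch_lightning as pl

    datamodule = create_microsplit_train_datamodule(
        train_data="./data/train",  # hypothetical directory of channel subfolders
        patch_size=(64, 64),
        data_type=DataType.SeparateTiffData,  # assumed enum member
        axes="SYX",
        batch_size=32,
        num_channels=2,
        train_dataloader_params={"num_workers": 4, "shuffle": True},
        val_dataloader_params={"num_workers": 4},
    )
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=datamodule)  # `model`: a compatible LightningModule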
+
+ class MicroSplitPredictDataModule(L.LightningDataModule):
+     """Lightning DataModule for MicroSplit-style prediction datasets.
+
+     Matches the interface of PredictDataModule, but internally uses MicroSplit
+     dataset logic for prediction.
+
+     Parameters
+     ----------
+     pred_config : MicroSplitDataConfig
+         Configuration for MicroSplit prediction.
+     pred_data : str or Path or numpy.ndarray
+         Prediction data, can be a path to a folder, a file or a numpy array.
+     read_source_func : Callable, optional
+         Function to read custom types.
+     extension_filter : str, optional
+         Filter to filter file extensions for custom types.
+     dataloader_params : dict, optional
+         Dataloader parameters.
+     """
+
+     def __init__(
+         self,
+         pred_config: MicroSplitDataConfig,
+         pred_data: Union[str, Path, NDArray],
+         read_source_func: Callable | None = None,
+         extension_filter: str = "",
+         dataloader_params: dict | None = None,
+     ) -> None:
+         """
+         Constructor for MicroSplit prediction data module.
+
+         Parameters
+         ----------
+         pred_config : MicroSplitDataConfig
+             Configuration for MicroSplit prediction.
+         pred_data : str or Path or numpy.ndarray
+             Prediction data, can be a path to a folder, a file or a numpy array.
+         read_source_func : Callable, optional
+             Function to read custom types, by default None.
+         extension_filter : str, optional
+             Filter to filter file extensions for custom types, by default "".
+         dataloader_params : dict, optional
+             Dataloader parameters, by default {}.
+         """
+         super().__init__()
+
+         if dataloader_params is None:
+             dataloader_params = {}
+         self.pred_config = pred_config
+         self.pred_data = pred_data
+         self.read_source_func = read_source_func or get_train_val_data
+         self.extension_filter = extension_filter
+         self.dataloader_params = dataloader_params
+
+     def prepare_data(self) -> None:
+         """Hook used to prepare the data before calling `setup`."""
+         # TODO currently data preparation is handled in dataset creation, revisit!
+         pass
+
+     def setup(self, stage: str | None = None) -> None:
+         """
+         Hook called at the beginning of predict.
+
+         Parameters
+         ----------
+         stage : str, optional
+             Stage, by default None.
+         """
+         # Create prediction dataset using LCMultiChDloader
+         self.predict_dataset = LCMultiChDloader(
+             self.pred_config,
+             self.pred_data,
+             load_data_fn=self.read_source_func,
+             val_fraction=0.0,  # no validation split for prediction
+             test_fraction=1.0,  # all data goes to the test split used for prediction
+         )
+         self.predict_dataset.set_mean_std(*self.pred_config.data_stats)
+
+     def predict_dataloader(self) -> DataLoader:
+         """
+         Create a dataloader for prediction.
+
+         Returns
+         -------
+         DataLoader
+             Prediction dataloader.
+         """
+         return DataLoader(
+             self.predict_dataset,
+             batch_size=self.pred_config.batch_size,
+             **self.dataloader_params,
+         )
+
+
+ def create_microsplit_predict_datamodule(
+     pred_data: Union[str, Path, NDArray],
+     tile_size: tuple,
+     data_type: DataType,
+     axes: str,
+     batch_size: int = 1,
+     num_channels: int = 2,
+     depth3D: int = 1,
+     grid_size: int | None = None,
+     multiscale_count: int | None = None,
+     data_stats: tuple | None = None,
+     tiling_mode: TilingMode = TilingMode.ShiftBoundary,
+     read_source_func: Callable | None = None,
+     extension_filter: str = "",
+     dataloader_params: dict | None = None,
+     **dataset_kwargs,
+ ) -> MicroSplitPredictDataModule:
+     """
+     Create a MicroSplitPredictDataModule for MicroSplit-style prediction datasets.
+
+     Parameters
+     ----------
+     pred_data : str or Path or numpy.ndarray
+         Prediction data, can be a path to a folder, a file or a numpy array.
+     tile_size : tuple
+         Size of one tile of data.
+     data_type : DataType
+         Type of the dataset (must be a DataType enum value).
+     axes : str
+         Axes of the data (e.g., 'SYX').
+     batch_size : int, default=1
+         Batch size for prediction dataloader.
+     num_channels : int, default=2
+         Number of channels in the input.
+     depth3D : int, default=1
+         Number of slices in 3D.
+     grid_size : int, optional
+         Grid size for patch extraction.
+     multiscale_count : int, optional
+         Number of LC scales.
+     data_stats : tuple, optional
+         Data statistics, by default None.
+     tiling_mode : TilingMode, default=ShiftBoundary
+         Tiling mode for patch extraction.
+     read_source_func : Callable, optional
+         Function to read the source data.
+     extension_filter : str, optional
+         File extension filter.
+     dataloader_params : dict, optional
+         Parameters for prediction dataloader.
+     **dataset_kwargs :
+         Additional arguments passed to MicroSplitDataConfig.
+
+     Returns
+     -------
+     MicroSplitPredictDataModule
+         Configured MicroSplitPredictDataModule instance.
+     """
+     if dataloader_params is None:
+         dataloader_params = {}
+
+     # Create prediction config with only valid parameters
+     prediction_config_params = {
+         "data_type": data_type,
+         "image_size": tile_size,
+         "num_channels": num_channels,
+         "depth3D": depth3D,
+         "grid_size": grid_size,
+         "multiscale_lowres_count": multiscale_count,
+         "data_stats": data_stats,
+         "tiling_mode": tiling_mode,
+         "batch_size": batch_size,
+         "datasplit_type": DataSplitType.Test,  # for prediction, use all data
+         **dataset_kwargs,
+     }
+
+     pred_config = MicroSplitDataConfig(**prediction_config_params)
+
+     # Remove batch_size from dataloader_params if present
+     if "batch_size" in dataloader_params:
+         del dataloader_params["batch_size"]
+
+     return MicroSplitPredictDataModule(
+         pred_config=pred_config,
+         pred_data=pred_data,
+         read_source_func=(
+             read_source_func if read_source_func is not None else get_train_val_data
+         ),
+         extension_filter=extension_filter,
+         dataloader_params=dataloader_params,
+     )
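A matching prediction sketch, reusing the statistics computed by the training datamodule via get_data_stats (all names are placeholders, continuing the training sketch above):

    stats, max_val = datamodule.get_data_stats()  # (mean, std) tuple and max value

    predict_datamodule = create_microsplit_predict_datamodule(
        pred_data="./data/test",  # hypothetical folder of channel subfolders
        tile_size=(64, 64),
        data_type=DataType.SeparateTiffData,  # assumed enum member
        axes="SYX",
        batch_size=8,
        num_channels=2,
        data_stats=stats,  # consumed in setup() via set_mean_std(*data_stats)
        dataloader_params={"num_workers": 4},
    )
    predictions = trainer.predict(model, datamodule=predict_datamodule)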