careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/lightning/train_data_module.py
@@ -0,0 +1,680 @@
+"""Training and validation Lightning data modules."""
+
+from pathlib import Path
+from typing import Any, Callable, Literal, Optional, Union
+
+import numpy as np
+import pytorch_lightning as L
+from numpy.typing import NDArray
+from torch.utils.data import DataLoader
+
+from careamics.config import DataConfig
+from careamics.config.data_model import TRANSFORMS_UNION
+from careamics.config.support import SupportedData
+from careamics.dataset.dataset_utils import (
+    get_files_size,
+    list_files,
+    validate_source_target_files,
+)
+from careamics.dataset.in_memory_dataset import (
+    InMemoryDataset,
+)
+from careamics.dataset.iterable_dataset import (
+    PathIterableDataset,
+)
+from careamics.file_io.read import get_read_func
+from careamics.utils import get_logger, get_ram_size
+
+DatasetType = Union[InMemoryDataset, PathIterableDataset]
+
+logger = get_logger(__name__)
+
+
+class TrainDataModule(L.LightningDataModule):
+    """
+    CAREamics Lightning training and validation data module.
+
+    The data module can be used with Path, str or numpy arrays. In the case of
+    numpy arrays, it loads and computes all the patches in memory. For Path and str
+    inputs, it calculates the total file size and estimates whether it can fit in
+    memory. If it does not, it iterates through the files. This behaviour can be
+    deactivated by setting `use_in_memory` to False, in which case it will
+    always use the iterating dataset to train on a Path or str.
+
+    The data can be either a folder containing images or a single file.
+
+    Validation can be omitted, in which case the validation data is extracted from
+    the training data. The percentage of the training data to use for validation,
+    as well as the minimum number of patches or files to split from the training
+    data can be set using `val_percentage` and `val_minimum_split`, respectively.
+
+    To read custom data types, you can set `data_type` to `custom` in `data_config`
+    and provide a function that returns a numpy array from a path as
+    `read_source_func` parameter. The function will receive a Path object and
+    an axes string as arguments, the axes being derived from the `data_config`.
+
+    You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
+    "*.czi") to filter the file extensions using `extension_filter`.
+
+    Parameters
+    ----------
+    data_config : DataModel
+        Pydantic model for CAREamics data configuration.
+    train_data : pathlib.Path or str or numpy.ndarray
+        Training data, can be a path to a folder, a file or a numpy array.
+    val_data : pathlib.Path or str or numpy.ndarray, optional
+        Validation data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    train_data_target : pathlib.Path or str or numpy.ndarray, optional
+        Training target data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    val_data_target : pathlib.Path or str or numpy.ndarray, optional
+        Validation target data, can be a path to a folder, a file or a numpy array,
+        by default None.
+    read_source_func : Callable, optional
+        Function to read the source data, by default None. Only used for `custom`
+        data type (see DataModel).
+    extension_filter : str, optional
+        Filter for file extensions, by default "". Only used for `custom` data types
+        (see DataModel).
+    val_percentage : float, optional
+        Percentage of the training data to use for validation, by default 0.1. Only
+        used if `val_data` is None.
+    val_minimum_split : int, optional
+        Minimum number of patches or files to split from the training data for
+        validation, by default 5. Only used if `val_data` is None.
+    use_in_memory : bool, optional
+        Use in memory dataset if possible, by default True.
+
+    Attributes
+    ----------
+    data_config : DataModel
+        CAREamics data configuration.
+    data_type : SupportedData
+        Expected data type, one of "tiff", "array" or "custom".
+    batch_size : int
+        Batch size.
+    use_in_memory : bool
+        Whether to use in memory dataset if possible.
+    train_data : pathlib.Path or numpy.ndarray
+        Training data.
+    val_data : pathlib.Path or numpy.ndarray
+        Validation data.
+    train_data_target : pathlib.Path or numpy.ndarray
+        Training target data.
+    val_data_target : pathlib.Path or numpy.ndarray
+        Validation target data.
+    val_percentage : float
+        Percentage of the training data to use for validation, if no validation data
+        is provided.
+    val_minimum_split : int
+        Minimum number of patches or files to split from the training data for
+        validation, if no validation data is provided.
+    read_source_func : Optional[Callable]
+        Function to read the source data, used if `data_type` is `custom`.
+    extension_filter : str
+        Filter for file extensions, used if `data_type` is `custom`.
+    """
+
+    def __init__(
+        self,
+        data_config: DataConfig,
+        train_data: Union[Path, str, NDArray],
+        val_data: Optional[Union[Path, str, NDArray]] = None,
+        train_data_target: Optional[Union[Path, str, NDArray]] = None,
+        val_data_target: Optional[Union[Path, str, NDArray]] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 5,
+        use_in_memory: bool = True,
+    ) -> None:
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        data_config : DataModel
+            Pydantic model for CAREamics data configuration.
+        train_data : pathlib.Path or str or numpy.ndarray
+            Training data, can be a path to a folder, a file or a numpy array.
+        val_data : pathlib.Path or str or numpy.ndarray, optional
+            Validation data, can be a path to a folder, a file or a numpy array, by
+            default None.
+        train_data_target : pathlib.Path or str or numpy.ndarray, optional
+            Training target data, can be a path to a folder, a file or a numpy
+            array, by default None.
+        val_data_target : pathlib.Path or str or numpy.ndarray, optional
+            Validation target data, can be a path to a folder, a file or a numpy
+            array, by default None.
+        read_source_func : Callable, optional
+            Function to read the source data, by default None. Only used for
+            `custom` data type (see DataModel).
+        extension_filter : str, optional
+            Filter for file extensions, by default "". Only used for `custom` data
+            types (see DataModel).
+        val_percentage : float, optional
+            Percentage of the training data to use for validation, by default 0.1.
+            Only used if `val_data` is None.
+        val_minimum_split : int, optional
+            Minimum number of patches or files to split from the training data for
+            validation, by default 5. Only used if `val_data` is None.
+        use_in_memory : bool, optional
+            Use in memory dataset if possible, by default True.
+
+        Raises
+        ------
+        NotImplementedError
+            Raised if target data is provided.
+        ValueError
+            If the input types are mixed (e.g. Path and numpy.ndarray).
+        ValueError
+            If the data type is `custom` and no `read_source_func` is provided.
+        ValueError
+            If the data type is `array` and the input is not a numpy array.
+        ValueError
+            If the data type is `tiff` and the input is neither a Path nor a str.
+        """
+        super().__init__()
+
+        # check input types coherence (no mixed types)
+        inputs = [train_data, val_data, train_data_target, val_data_target]
+        types_set = {type(i) for i in inputs}
+        if len(types_set) > 2:  # None + expected type
+            raise ValueError(
+                f"Inputs for `train_data`, `val_data`, `train_data_target` and "
+                f"`val_data_target` must be of the same type or None. Got "
+                f"{types_set}."
+            )
+
+        # check that a read source function is provided for custom types
+        if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
+            raise ValueError(
+                f"Data type {SupportedData.CUSTOM} is not allowed without "
+                f"specifying a `read_source_func` and an `extension_filter`."
+            )
+
+        # check correct input type
+        if (
+            isinstance(train_data, np.ndarray)
+            and data_config.data_type != SupportedData.ARRAY
+        ):
+            raise ValueError(
+                f"Received a numpy array as input, but the data type was set to "
+                f"{data_config.data_type}. Set the data type in the configuration "
+                f"to {SupportedData.ARRAY} to train on numpy arrays."
+            )
+
+        # and that Path or str are passed, if tiff file type specified
+        elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
+            data_config.data_type != SupportedData.TIFF
+            and data_config.data_type != SupportedData.CUSTOM
+        ):
+            raise ValueError(
+                f"Received a path as input, but the data type was neither set to "
+                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                f"in the configuration to {SupportedData.TIFF} or "
+                f"{SupportedData.CUSTOM} to train on files."
+            )
+
+        # configuration
+        self.data_config: DataConfig = data_config
+        self.data_type: str = data_config.data_type
+        self.batch_size: int = data_config.batch_size
+        self.use_in_memory: bool = use_in_memory
+
+        # data: make data Path or np.ndarray, use type annotations for mypy
+        self.train_data: Union[Path, NDArray] = (
+            Path(train_data) if isinstance(train_data, str) else train_data
+        )
+
+        self.val_data: Union[Path, NDArray] = (
+            Path(val_data) if isinstance(val_data, str) else val_data
+        )
+
+        self.train_data_target: Union[Path, NDArray] = (
+            Path(train_data_target)
+            if isinstance(train_data_target, str)
+            else train_data_target
+        )
+
+        self.val_data_target: Union[Path, NDArray] = (
+            Path(val_data_target)
+            if isinstance(val_data_target, str)
+            else val_data_target
+        )
+
+        # validation split
+        self.val_percentage = val_percentage
+        self.val_minimum_split = val_minimum_split
+
+        # read source function corresponding to the requested type
+        if data_config.data_type == SupportedData.CUSTOM.value:
+            # mypy check
+            assert read_source_func is not None
+
+            self.read_source_func: Callable = read_source_func
+
+        elif data_config.data_type != SupportedData.ARRAY:
+            self.read_source_func = get_read_func(data_config.data_type)
+
+        self.extension_filter: str = extension_filter
+
+        # Pytorch dataloader parameters
+        self.dataloader_params: dict[str, Any] = (
+            data_config.dataloader_params if data_config.dataloader_params else {}
+        )
+
+    def prepare_data(self) -> None:
+        """
+        Hook used to prepare the data before calling `setup`.
+
+        Here, we only need to examine the data if it was provided as a str or a Path.
+
+        TODO: from lightning doc:
+        prepare_data is called from the main process. It is not recommended to assign
+        state here (e.g. self.x = y) since it is called on a single process and if you
+        assign states here then they won't be available for other processes.
+
+        https://lightning.ai/docs/pytorch/stable/data/datamodule.html
+        """
+        # if the data is a Path or a str
+        if (
+            not isinstance(self.train_data, np.ndarray)
+            and not isinstance(self.val_data, np.ndarray)
+            and not isinstance(self.train_data_target, np.ndarray)
+            and not isinstance(self.val_data_target, np.ndarray)
+        ):
+            # list training files
+            self.train_files = list_files(
+                self.train_data, self.data_type, self.extension_filter
+            )
+            self.train_files_size = get_files_size(self.train_files)
+
+            # list validation files
+            if self.val_data is not None:
+                self.val_files = list_files(
+                    self.val_data, self.data_type, self.extension_filter
+                )
+
+            # same for target data
+            if self.train_data_target is not None:
+                self.train_target_files: list[Path] = list_files(
+                    self.train_data_target, self.data_type, self.extension_filter
+                )
+
+                # verify that they match the training data
+                validate_source_target_files(self.train_files, self.train_target_files)
+
+            if self.val_data_target is not None:
+                self.val_target_files = list_files(
+                    self.val_data_target, self.data_type, self.extension_filter
+                )
+
+                # verify that they match the validation data
+                validate_source_target_files(self.val_files, self.val_target_files)
+
+    def setup(self, *args: Any, **kwargs: Any) -> None:
+        """Hook called at the beginning of fit, validate, or predict.
+
+        Parameters
+        ----------
+        *args : Any
+            Unused.
+        **kwargs : Any
+            Unused.
+        """
+        # if numpy array
+        if self.data_type == SupportedData.ARRAY:
+            # mypy checks
+            assert isinstance(self.train_data, np.ndarray)
+            if self.train_data_target is not None:
+                assert isinstance(self.train_data_target, np.ndarray)
+
+            # train dataset
+            self.train_dataset: DatasetType = InMemoryDataset(
+                data_config=self.data_config,
+                inputs=self.train_data,
+                input_target=self.train_data_target,
+            )
+
+            # validation dataset
+            if self.val_data is not None:
+                # mypy checks
+                assert isinstance(self.val_data, np.ndarray)
+                if self.val_data_target is not None:
+                    assert isinstance(self.val_data_target, np.ndarray)
+
+                # create its own dataset
+                self.val_dataset: DatasetType = InMemoryDataset(
+                    data_config=self.data_config,
+                    inputs=self.val_data,
+                    input_target=self.val_data_target,
+                )
+            else:
+                # extract validation from the training patches
+                self.val_dataset = self.train_dataset.split_dataset(
+                    percentage=self.val_percentage,
+                    minimum_patches=self.val_minimum_split,
+                )
+
+        # else we read files
+        else:
+            # Heuristics, if the file size is smaller than 80% of the RAM,
+            # we run the training in memory, otherwise we switch to iterable dataset
+            # The switch is deactivated if use_in_memory is False
+            if self.use_in_memory and self.train_files_size < get_ram_size() * 0.8:
+                # train dataset
+                self.train_dataset = InMemoryDataset(
+                    data_config=self.data_config,
+                    inputs=self.train_files,
+                    input_target=(
+                        self.train_target_files if self.train_data_target else None
+                    ),
+                    read_source_func=self.read_source_func,
+                )
+
+                # validation dataset
+                if self.val_data is not None:
+                    self.val_dataset = InMemoryDataset(
+                        data_config=self.data_config,
+                        inputs=self.val_files,
+                        input_target=(
+                            self.val_target_files if self.val_data_target else None
+                        ),
+                        read_source_func=self.read_source_func,
+                    )
+                else:
+                    # split dataset
+                    self.val_dataset = self.train_dataset.split_dataset(
+                        percentage=self.val_percentage,
+                        minimum_patches=self.val_minimum_split,
+                    )
+
+            # else if the data is too large, load file by file during training
+            else:
+                # create training dataset
+                self.train_dataset = PathIterableDataset(
+                    data_config=self.data_config,
+                    src_files=self.train_files,
+                    target_files=(
+                        self.train_target_files if self.train_data_target else None
+                    ),
+                    read_source_func=self.read_source_func,
+                )
+
+                # create validation dataset
+                if self.val_data is not None:
+                    # create its own dataset
+                    self.val_dataset = PathIterableDataset(
+                        data_config=self.data_config,
+                        src_files=self.val_files,
+                        target_files=(
+                            self.val_target_files if self.val_data_target else None
+                        ),
+                        read_source_func=self.read_source_func,
+                    )
+                elif len(self.train_files) <= self.val_minimum_split:
+                    raise ValueError(
+                        f"Not enough files to split a minimum of "
+                        f"{self.val_minimum_split} files, got {len(self.train_files)} "
+                        f"files."
+                    )
+                else:
+                    # extract validation from the training patches
+                    self.val_dataset = self.train_dataset.split_dataset(
+                        percentage=self.val_percentage,
+                        minimum_number=self.val_minimum_split,
+                    )
+
+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        Returns
+        -------
+        tuple of list
+            Means and standard deviations across channels of the training data.
+        """
+        return self.train_dataset.get_data_statistics()
+
+    def train_dataloader(self) -> Any:
+        """
+        Create a dataloader for training.
+
+        Returns
+        -------
+        Any
+            Training dataloader.
+        """
+        return DataLoader(
+            self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
+        )
+
+    def val_dataloader(self) -> Any:
+        """
+        Create a dataloader for validation.
+
+        Returns
+        -------
+        Any
+            Validation dataloader.
+        """
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+        )
+
+
+def create_train_datamodule(
+    train_data: Union[str, Path, NDArray],
+    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+    patch_size: list[int],
+    axes: str,
+    batch_size: int,
+    val_data: Optional[Union[str, Path, NDArray]] = None,
+    transforms: Optional[list[TRANSFORMS_UNION]] = None,
+    train_target_data: Optional[Union[str, Path, NDArray]] = None,
+    val_target_data: Optional[Union[str, Path, NDArray]] = None,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    val_percentage: float = 0.1,
+    val_minimum_patches: int = 5,
+    dataloader_params: Optional[dict] = None,
+    use_in_memory: bool = True,
+    use_n2v2: bool = False,
+    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+    struct_n2v_span: int = 5,
+) -> TrainDataModule:
+    """Create a TrainDataModule.
+
+    This function is used to explicitly pass the parameters usually contained in a
+    `data_model` configuration to a TrainDataModule.
+
+    Since the lightning datamodule has no access to the model, make sure that the
+    parameters passed to the datamodule are consistent with the model's requirements
+    and are coherent.
+
+    The data module can be used with Path, str or numpy arrays. In the case of
+    numpy arrays, it loads and computes all the patches in memory. For Path and str
+    inputs, it calculates the total file size and estimates whether it can fit in
+    memory. If it does not, it iterates through the files. This behaviour can be
+    deactivated by setting `use_in_memory` to False, in which case it will
+    always use the iterating dataset to train on a Path or str.
+
+    To use array data, set `data_type` to `array` and pass a numpy array to
+    `train_data`.
+
+    In particular, N2V requires a specific transformation (N2V manipulation), which is
+    not compatible with supervised training. The default transformations applied to the
+    training patches are defined in `careamics.config.data_model`. To use different
+    transformations, pass a list of transforms. See examples for more details.
+
+    By default, CAREamics only supports types defined in
+    `careamics.config.support.SupportedData`. To read custom data types, you can set
+    `data_type` to `custom` and provide a function that returns a numpy array from a
+    path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression (e.g.
+    "*.jpeg") to filter the file extensions using `extension_filter`.
+
+    In the absence of validation data, the validation data is extracted from the
+    training data. The percentage of the training data to use for validation, as well
+    as the minimum number of patches to split from the training data for validation
+    can be set using `val_percentage` and `val_minimum_patches`, respectively.
+
+    In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders,
+    except for `batch_size`, which is set by the `batch_size` parameter.
+
+    Finally, if you intend to use the N2V family of algorithms, you can set `use_n2v2`
+    to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to
+    define the axis and span of the structN2V mask. These parameters have no effect
+    if `train_target_data` or `transforms` are provided.
+
+    Parameters
+    ----------
+    train_data : pathlib.Path or str or numpy.ndarray
+        Training data.
+    data_type : {"array", "tiff", "custom"}
+        Data type, see `SupportedData` for available options.
+    patch_size : list of int
+        Patch size, 2D or 3D patch size.
+    axes : str
+        Axes of the data, chosen amongst SCZYX.
+    batch_size : int
+        Batch size.
+    val_data : pathlib.Path or str or numpy.ndarray, optional
+        Validation data, by default None.
+    transforms : list of Transforms, optional
+        List of transforms to apply to training patches. If None, default transforms
+        are applied.
+    train_target_data : pathlib.Path or str or numpy.ndarray, optional
+        Training target data, by default None.
+    val_target_data : pathlib.Path or str or numpy.ndarray, optional
+        Validation target data, by default None.
+    read_source_func : Callable, optional
+        Function to read the source data, used if `data_type` is `custom`, by
+        default None.
+    extension_filter : str, optional
+        Filter for file extensions, used if `data_type` is `custom`, by default "".
+    val_percentage : float, optional
+        Percentage of the training data to use for validation if no validation data
+        is given, by default 0.1.
+    val_minimum_patches : int, optional
+        Minimum number of patches to split from the training data for validation if
+        no validation data is given, by default 5.
+    dataloader_params : dict, optional
+        Pytorch dataloader parameters, by default {}.
+    use_in_memory : bool, optional
+        Use in memory dataset if possible, by default True.
+    use_n2v2 : bool, optional
+        Use N2V2 transformation during training, by default False.
+    struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
+        Axis for the structN2V mask; structN2V is not applied if set to "none", by
+        default "none".
+    struct_n2v_span : int, optional
+        Span for the structN2V mask, by default 5.
+
+    Returns
+    -------
+    TrainDataModule
+        CAREamics training Lightning data module.
+
+    Examples
+    --------
+    Create a TrainDataModule with default transforms from a numpy array:
+    >>> import numpy as np
+    >>> from careamics.lightning import create_train_datamodule
+    >>> my_array = np.arange(256).reshape(16, 16)
+    >>> data_module = create_train_datamodule(
+    ...     train_data=my_array,
+    ...     data_type="array",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ... )
+
+    For custom data types (those not supported by CAREamics), one can pass a read
+    function and a filter for the file extensions:
+    >>> import numpy as np
+    >>> from careamics.lightning import create_train_datamodule
+    >>>
+    >>> def read_npy(path):
+    ...     return np.load(path)
+    >>>
+    >>> data_module = create_train_datamodule(
+    ...     train_data="path/to/data",
+    ...     data_type="custom",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ...     read_source_func=read_npy,
+    ...     extension_filter="*.npy",
+    ... )
+
+    If you want to use a different set of transformations, you can pass a list of
+    transforms:
+    >>> import numpy as np
+    >>> from careamics.lightning import create_train_datamodule
+    >>> from careamics.config.support import SupportedTransform
+    >>> my_array = np.arange(256).reshape(16, 16)
+    >>> my_transforms = [
+    ...     {
+    ...         "name": SupportedTransform.XY_FLIP.value,
+    ...     }
+    ... ]
+    >>> data_module = create_train_datamodule(
+    ...     train_data=my_array,
+    ...     data_type="array",
+    ...     patch_size=(8, 8),
+    ...     axes='YX',
+    ...     batch_size=2,
+    ...     transforms=my_transforms,
+    ... )
+    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    data_dict: dict[str, Any] = {
+        "mode": "train",
+        "data_type": data_type,
+        "patch_size": patch_size,
+        "axes": axes,
+        "batch_size": batch_size,
+        "dataloader_params": dataloader_params,
+    }
+
+    # if transforms are passed (otherwise it will use the default ones)
+    if transforms is not None:
+        data_dict["transforms"] = transforms
+
+    # validate configuration
+    data_config = DataConfig(**data_dict)
+
+    # N2V specific checks, N2V, structN2V, and transforms
+    if data_config.has_n2v_manipulate():
+        # there is no target, so N2V2 and structN2V can be changed
+        if train_target_data is None:
+            data_config.set_N2V2(use_n2v2)
+            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+        else:
+            raise ValueError(
+                "Cannot have both supervised training (target data) and "
+                "N2V manipulation in the transforms. Pass a list of transforms "
+                "that is compatible with your supervised training."
+            )
+
+    # sanity check on the dataloader parameters
+    if "batch_size" in dataloader_params:
+        # remove it
+        del dataloader_params["batch_size"]
+
+    return TrainDataModule(
+        data_config=data_config,
+        train_data=train_data,
+        val_data=val_data,
+        train_data_target=train_target_data,
+        val_data_target=val_target_data,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        val_percentage=val_percentage,
+        val_minimum_split=val_minimum_patches,
+        use_in_memory=use_in_memory,
+    )
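For readers skimming the diff, the sketch below shows how the new datamodule can be exercised on its own, outside a Lightning Trainer. It is illustrative only: the random input array and the num_workers value are arbitrary choices, and the exact batch structure depends on the transforms configured through DataConfig.

import numpy as np
from careamics.lightning import create_train_datamodule

# arbitrary 2D input; arrays are always handled by the in-memory dataset
rng = np.random.default_rng(0)
train_array = rng.random((64, 64)).astype(np.float32)

data_module = create_train_datamodule(
    train_data=train_array,
    data_type="array",
    patch_size=[8, 8],
    axes="YX",
    batch_size=2,
    dataloader_params={"num_workers": 0},
)

# Lightning normally calls these hooks itself; calling them manually is only
# for illustration. For "tiff" or "custom" paths, setup() compares the total
# file size against ~80% of the available RAM and falls back to
# PathIterableDataset, unless use_in_memory=False forces the iterable path.
data_module.prepare_data()
data_module.setup()
print(data_module.get_data_statistics())  # per-channel means and stds
first_batch = next(iter(data_module.train_dataloader()))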
careamics/losses/__init__.py
@@ -0,0 +1,15 @@
+"""Losses module."""
+
+__all__ = [
+    "loss_factory",
+    "mae_loss",
+    "mse_loss",
+    "n2v_loss",
+    "denoisplit_loss",
+    "musplit_loss",
+    "denoisplit_musplit_loss",
+]
+
+from .fcn.losses import mae_loss, mse_loss, n2v_loss
+from .loss_factory import loss_factory
+from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
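This __init__ only re-exports the loss functions at package level. A minimal hedged sketch of what that enables follows; the comment about loss_factory is an assumption, since its exact argument is defined in careamics/losses/loss_factory.py, which is not shown in this diff.

# Imports made possible by the re-exports in careamics/losses/__init__.py.
from careamics.losses import (
    denoisplit_loss,
    denoisplit_musplit_loss,
    loss_factory,
    mae_loss,
    mse_loss,
    musplit_loss,
    n2v_loss,
)

# Assumption: loss_factory maps a configured loss to one of the callables
# above, so training code never has to import from the fcn/lvae submodules.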
careamics/losses/fcn/__init__.py
@@ -0,0 +1 @@
+"""FCN losses."""