careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/lightning_datamodule.py (added)
@@ -0,0 +1,743 @@
1
+ """Training and validation Lightning data modules."""
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as L
7
+ from albumentations import Compose
8
+ from torch.utils.data import DataLoader
9
+
10
+ from careamics.config import DataConfig
11
+ from careamics.config.data_model import TRANSFORMS_UNION
12
+ from careamics.config.support import SupportedData
13
+ from careamics.dataset.dataset_utils import (
14
+ get_files_size,
15
+ get_read_func,
16
+ list_files,
17
+ validate_source_target_files,
18
+ )
19
+ from careamics.dataset.in_memory_dataset import (
20
+ InMemoryDataset,
21
+ )
22
+ from careamics.dataset.iterable_dataset import (
23
+ PathIterableDataset,
24
+ )
25
+ from careamics.utils import get_logger, get_ram_size
26
+
27
+ DatasetType = Union[InMemoryDataset, PathIterableDataset]
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
+ class CAREamicsTrainData(L.LightningDataModule):
33
+ """
34
+ CAREamics Lightning training and validation data module.
35
+
36
+ The data module can be used with Path, str or numpy arrays. In the case of
37
+ numpy arrays, it loads and computes all the patches in memory. For Path and str
38
+ inputs, it calculates the total file size and estimates whether it can fit in
39
+ memory. If it does not, it iterates through the files. This behaviour can be
40
+ deactivated by setting `use_in_memory` to False, in which case it will
41
+ always use the iterable dataset to train on a Path or str.
42
+
43
+ The data can be either a folder containing images or a single file.
44
+
45
+ Validation can be omitted, in which case the validation data is extracted from
46
+ the training data. The percentage of the training data to use for validation,
47
+ as well as the minimum number of patches or files to split from the training
48
+ data can be set using `val_percentage` and `val_minimum_split`, respectively.
49
+
50
+ To read custom data types, you can set `data_type` to `custom` in `data_config`
51
+ and provide a function that returns a numpy array from a path as the
52
+ `read_source_func` parameter. The function will receive a Path object and
53
+ an axes string as arguments, the axes being derived from the `data_config`.
54
+
55
+ You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
56
+ "*.czi") to filter the files extension using `extension_filter`.
57
+
58
+ Parameters
59
+ ----------
60
+ data_config : DataModel
61
+ Pydantic model for CAREamics data configuration.
62
+ train_data : Union[Path, str, np.ndarray]
63
+ Training data, can be a path to a folder, a file or a numpy array.
64
+ val_data : Optional[Union[Path, str, np.ndarray]], optional
65
+ Validation data, can be a path to a folder, a file or a numpy array, by
66
+ default None.
67
+ train_data_target : Optional[Union[Path, str, np.ndarray]], optional
68
+ Training target data, can be a path to a folder, a file or a numpy array, by
69
+ default None.
70
+ val_data_target : Optional[Union[Path, str, np.ndarray]], optional
71
+ Validation target data, can be a path to a folder, a file or a numpy array,
72
+ by default None.
73
+ read_source_func : Optional[Callable], optional
74
+ Function to read the source data, by default None. Only used for `custom`
75
+ data type (see DataModel).
76
+ extension_filter : str, optional
77
+ Filter for file extensions, by default "". Only used for `custom` data types
78
+ (see DataModel).
79
+ val_percentage : float, optional
80
+ Percentage of the training data to use for validation, by default 0.1. Only
81
+ used if `val_data` is None.
82
+ val_minimum_split : int, optional
83
+ Minimum number of patches or files to split from the training data for
84
+ validation, by default 5. Only used if `val_data` is None.
85
+ use_in_memory : bool, optional
86
+ Use in memory dataset if possible, by default True.
87
+
88
+ Attributes
89
+ ----------
90
+ data_config : DataModel
91
+ CAREamics data configuration.
92
+ data_type : SupportedData
93
+ Expected data type, one of "tiff", "array" or "custom".
94
+ batch_size : int
95
+ Batch size.
96
+ use_in_memory : bool
97
+ Whether to use in memory dataset if possible.
98
+ train_data : Union[Path, str, np.ndarray]
99
+ Training data.
100
+ val_data : Optional[Union[Path, str, np.ndarray]]
101
+ Validation data.
102
+ train_data_target : Optional[Union[Path, str, np.ndarray]]
103
+ Training target data.
104
+ val_data_target : Optional[Union[Path, str, np.ndarray]]
105
+ Validation target data.
106
+ val_percentage : float
107
+ Percentage of the training data to use for validation, if no validation data is
108
+ provided.
109
+ val_minimum_split : int
110
+ Minimum number of patches or files to split from the training data for
111
+ validation, if no validation data is provided.
112
+ read_source_func : Optional[Callable]
113
+ Function to read the source data, used if `data_type` is `custom`.
114
+ extension_filter : str
115
+ Filter for file extensions, used if `data_type` is `custom`.
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ data_config: DataConfig,
121
+ train_data: Union[Path, str, np.ndarray],
122
+ val_data: Optional[Union[Path, str, np.ndarray]] = None,
123
+ train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
124
+ val_data_target: Optional[Union[Path, str, np.ndarray]] = None,
125
+ read_source_func: Optional[Callable] = None,
126
+ extension_filter: str = "",
127
+ val_percentage: float = 0.1,
128
+ val_minimum_split: int = 5,
129
+ use_in_memory: bool = True,
130
+ ) -> None:
131
+ """
132
+ Constructor.
133
+
134
+ Parameters
135
+ ----------
136
+ data_config : DataModel
137
+ Pydantic model for CAREamics data configuration.
138
+ train_data : Union[Path, str, np.ndarray]
139
+ Training data, can be a path to a folder, a file or a numpy array.
140
+ val_data : Optional[Union[Path, str, np.ndarray]], optional
141
+ Validation data, can be a path to a folder, a file or a numpy array, by
142
+ default None.
143
+ train_data_target : Optional[Union[Path, str, np.ndarray]], optional
144
+ Training target data, can be a path to a folder, a file or a numpy array, by
145
+ default None.
146
+ val_data_target : Optional[Union[Path, str, np.ndarray]], optional
147
+ Validation target data, can be a path to a folder, a file or a numpy array,
148
+ by default None.
149
+ read_source_func : Optional[Callable], optional
150
+ Function to read the source data, by default None. Only used for `custom`
151
+ data type (see DataModel).
152
+ extension_filter : str, optional
153
+ Filter for file extensions, by default "". Only used for `custom` data types
154
+ (see DataModel).
155
+ val_percentage : float, optional
156
+ Percentage of the training data to use for validation, by default 0.1. Only
157
+ used if `val_data` is None.
158
+ val_minimum_split : int, optional
159
+ Minimum number of patches or files to split from the training data for
160
+ validation, by default 5. Only used if `val_data` is None.
161
+ use_in_memory : bool, optional
162
+ Use in memory dataset if possible, by default True.
163
+
164
+ Raises
165
+ ------
166
+ NotImplementedError
167
+ Raised if target data is provided.
168
+ ValueError
169
+ If the input types are mixed (e.g. Path and np.ndarray).
170
+ ValueError
171
+ If the data type is `custom` and no `read_source_func` is provided.
172
+ ValueError
173
+ If the data type is `array` and the input is not a numpy array.
174
+ ValueError
175
+ If the data type is `tiff` and the input is neither a Path nor a str.
176
+ """
177
+ super().__init__()
178
+
179
+ # check input types coherence (no mixed types)
180
+ inputs = [train_data, val_data, train_data_target, val_data_target]
181
+ types_set = {type(i) for i in inputs}
182
+ if len(types_set) > 2: # None + expected type
183
+ raise ValueError(
184
+ f"Inputs for `train_data`, `val_data`, `train_data_target` and "
185
+ f"`val_data_target` must be of the same type or None. Got "
186
+ f"{types_set}."
187
+ )
188
+
189
+ # check that a read source function is provided for custom types
190
+ if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
191
+ raise ValueError(
192
+ f"Data type {SupportedData.CUSTOM} is not allowed without "
193
+ f"specifying a `read_source_func` and an `extension_filer`."
194
+ )
195
+
196
+ # check correct input type
197
+ if (
198
+ isinstance(train_data, np.ndarray)
199
+ and data_config.data_type != SupportedData.ARRAY
200
+ ):
201
+ raise ValueError(
202
+ f"Received a numpy array as input, but the data type was set to "
203
+ f"{data_config.data_type}. Set the data type in the configuration "
204
+ f"to {SupportedData.ARRAY} to train on numpy arrays."
205
+ )
206
+
207
+ # and that Path or str are passed, if tiff file type specified
208
+ elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
209
+ data_config.data_type != SupportedData.TIFF
210
+ and data_config.data_type != SupportedData.CUSTOM
211
+ ):
212
+ raise ValueError(
213
+ f"Received a path as input, but the data type was neither set to "
214
+ f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
215
+ f"in the configuration to {SupportedData.TIFF} or "
216
+ f"{SupportedData.CUSTOM} to train on files."
217
+ )
218
+
219
+ # configuration
220
+ self.data_config = data_config
221
+ self.data_type = data_config.data_type
222
+ self.batch_size = data_config.batch_size
223
+ self.use_in_memory = use_in_memory
224
+
225
+ # data
226
+ self.train_data = train_data
227
+ self.val_data = val_data
228
+
229
+ self.train_data_target = train_data_target
230
+ self.val_data_target = val_data_target
231
+ self.val_percentage = val_percentage
232
+ self.val_minimum_split = val_minimum_split
233
+
234
+ # read source function corresponding to the requested type
235
+ if data_config.data_type == SupportedData.CUSTOM:
236
+ # mypy check
237
+ assert read_source_func is not None
238
+
239
+ self.read_source_func: Callable = read_source_func
240
+
241
+ elif data_config.data_type != SupportedData.ARRAY:
242
+ self.read_source_func = get_read_func(data_config.data_type)
243
+
244
+ self.extension_filter = extension_filter
245
+
246
+ # Pytorch dataloader parameters
247
+ self.dataloader_params = (
248
+ data_config.dataloader_params if data_config.dataloader_params else {}
249
+ )
250
+
251
+ def prepare_data(self) -> None:
252
+ """
253
+ Hook used to prepare the data before calling `setup`.
254
+
255
+ Here, we only need to examine the data if it was provided as a str or a Path.
256
+
257
+ TODO: from lightning doc:
258
+ prepare_data is called from the main process. It is not recommended to assign
259
+ state here (e.g. self.x = y) since it is called on a single process and if you
260
+ assign states here then they won't be available for other processes.
261
+
262
+ https://lightning.ai/docs/pytorch/stable/data/datamodule.html
263
+ """
264
+ # if the data is a Path or a str
265
+ if (
266
+ not isinstance(self.train_data, np.ndarray)
267
+ and not isinstance(self.val_data, np.ndarray)
268
+ and not isinstance(self.train_data_target, np.ndarray)
269
+ and not isinstance(self.val_data_target, np.ndarray)
270
+ ):
271
+ # list training files
272
+ self.train_files = list_files(
273
+ self.train_data, self.data_type, self.extension_filter
274
+ )
275
+ self.train_files_size = get_files_size(self.train_files)
276
+
277
+ # list validation files
278
+ if self.val_data is not None:
279
+ self.val_files = list_files(
280
+ self.val_data, self.data_type, self.extension_filter
281
+ )
282
+
283
+ # same for target data
284
+ if self.train_data_target is not None:
285
+ self.train_target_files: List[Path] = list_files(
286
+ self.train_data_target, self.data_type, self.extension_filter
287
+ )
288
+
289
+ # verify that they match the training data
290
+ validate_source_target_files(self.train_files, self.train_target_files)
291
+
292
+ if self.val_data_target is not None:
293
+ self.val_target_files = list_files(
294
+ self.val_data_target, self.data_type, self.extension_filter
295
+ )
296
+
297
+ # verify that they match the validation data
298
+ validate_source_target_files(self.val_files, self.val_target_files)
299
+
300
+ def setup(self, *args: Any, **kwargs: Any) -> None:
301
+ """Hook called at the beginning of fit, validate, or predict.
302
+
303
+ Parameters
304
+ ----------
305
+ *args : Any
306
+ Unused.
307
+ **kwargs : Any
308
+ Unused.
309
+ """
310
+ # if numpy array
311
+ if self.data_type == SupportedData.ARRAY:
312
+ # train dataset
313
+ self.train_dataset: DatasetType = InMemoryDataset(
314
+ data_config=self.data_config,
315
+ inputs=self.train_data,
316
+ data_target=self.train_data_target,
317
+ )
318
+
319
+ # validation dataset
320
+ if self.val_data is not None:
321
+ # create its own dataset
322
+ self.val_dataset: DatasetType = InMemoryDataset(
323
+ data_config=self.data_config,
324
+ inputs=self.val_data,
325
+ data_target=self.val_data_target,
326
+ )
327
+ else:
328
+ # extract validation from the training patches
329
+ self.val_dataset = self.train_dataset.split_dataset(
330
+ percentage=self.val_percentage,
331
+ minimum_patches=self.val_minimum_split,
332
+ )
333
+
334
+ # else we read files
335
+ else:
336
+ # Heuristic: if the file size is smaller than 80% of the RAM,
337
+ # we run the training in memory, otherwise we switch to an iterable dataset.
338
+ # The switch is deactivated if use_in_memory is False.
339
+ if self.use_in_memory and self.train_files_size < get_ram_size() * 0.8:
340
+ # train dataset
341
+ self.train_dataset = InMemoryDataset(
342
+ data_config=self.data_config,
343
+ inputs=self.train_files,
344
+ data_target=self.train_target_files
345
+ if self.train_data_target
346
+ else None,
347
+ read_source_func=self.read_source_func,
348
+ )
349
+
350
+ # validation dataset
351
+ if self.val_data is not None:
352
+ self.val_dataset = InMemoryDataset(
353
+ data_config=self.data_config,
354
+ inputs=self.val_files,
355
+ data_target=self.val_target_files
356
+ if self.val_data_target
357
+ else None,
358
+ read_source_func=self.read_source_func,
359
+ )
360
+ else:
361
+ # split dataset
362
+ self.val_dataset = self.train_dataset.split_dataset(
363
+ percentage=self.val_percentage,
364
+ minimum_patches=self.val_minimum_split,
365
+ )
366
+
367
+ # else if the data is too large, load file by file during training
368
+ else:
369
+ # create training dataset
370
+ self.train_dataset = PathIterableDataset(
371
+ data_config=self.data_config,
372
+ src_files=self.train_files,
373
+ target_files=self.train_target_files
374
+ if self.train_data_target
375
+ else None,
376
+ read_source_func=self.read_source_func,
377
+ )
378
+
379
+ # create validation dataset
380
+ if self.val_files is not None:
381
+ # create its own dataset
382
+ self.val_dataset = PathIterableDataset(
383
+ data_config=self.data_config,
384
+ src_files=self.val_files,
385
+ target_files=self.val_target_files
386
+ if self.val_data_target
387
+ else None,
388
+ read_source_func=self.read_source_func,
389
+ )
390
+ elif len(self.train_files) <= self.val_minimum_split:
391
+ raise ValueError(
392
+ f"Not enough files to split a minimum of "
393
+ f"{self.val_minimum_split} files, got {len(self.train_files)} "
394
+ f"files."
395
+ )
396
+ else:
397
+ # extract validation from the training patches
398
+ self.val_dataset = self.train_dataset.split_dataset(
399
+ percentage=self.val_percentage,
400
+ minimum_files=self.val_minimum_split,
401
+ )
402
+
403
+ def train_dataloader(self) -> Any:
404
+ """
405
+ Create a dataloader for training.
406
+
407
+ Returns
408
+ -------
409
+ Any
410
+ Training dataloader.
411
+ """
412
+ return DataLoader(
413
+ self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
414
+ )
415
+
416
+ def val_dataloader(self) -> Any:
417
+ """
418
+ Create a dataloader for validation.
419
+
420
+ Returns
421
+ -------
422
+ Any
423
+ Validation dataloader.
424
+ """
425
+ return DataLoader(
426
+ self.val_dataset,
427
+ batch_size=self.batch_size,
428
+ )
429
+
430
+
431
+ class TrainingDataWrapper(CAREamicsTrainData):
432
+ """
433
+ Wrapper around the CAREamics Lightning training data module.
434
+
435
+ This class is used to explicitly pass the parameters usually contained in a
436
+ `data_model` configuration.
437
+
438
+ Since the Lightning datamodule has no access to the model, make sure that the
439
+ parameters passed to the datamodule are consistent with the model's
440
+ requirements.
441
+
442
+ The data module can be used with Path, str or numpy arrays. In the case of
443
+ numpy arrays, it loads and computes all the patches in memory. For Path and str
444
+ inputs, it calculates the total file size and estimates whether it can fit in
445
+ memory. If it does not, it iterates through the files. This behaviour can be
446
+ deactivated by setting `use_in_memory` to False, in which case it will
447
+ always use the iterable dataset to train on a Path or str.
448
+
449
+ To use array data, set `data_type` to `array` and pass a numpy array to
450
+ `train_data`.
451
+
452
+ Note that N2V requires a specific transformation (N2V manipulation), which is
453
+ not compatible with supervised training. The default transformations applied to the
454
+ training patches are defined in `careamics.config.data_model`. To use different
455
+ transformations, pass a list of transforms or an Albumentations `Compose` as
456
+ the `transforms` parameter. See examples for more details.
457
+
458
+ By default, CAREamics only supports types defined in
459
+ `careamics.config.support.SupportedData`. To read custom data types, you can set
460
+ `data_type` to `custom` and provide a function that returns a numpy array from a
461
+ path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression (e.g.
462
+ "*.jpeg") to filter the files extension using `extension_filter`.
463
+
464
+ In the absence of validation data, the validation data is extracted from the
465
+ training data. The percentage of the training data to use for validation, as well as
466
+ the minimum number of patches to split from the training data for validation can be
467
+ set using `val_percentage` and `val_minimum_patches`, respectively.
468
+
469
+ In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders,
470
+ except for `batch_size`, which is set by the `batch_size` parameter.
471
+
472
+ Finally, if you intend to use the N2V family of algorithms, you can set `use_n2v2`
473
+ to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to
474
+ define the axis and span of the structN2V mask. These parameters have no effect
475
+ if `train_target_data` or `transforms` are provided.
476
+
477
+ Parameters
478
+ ----------
479
+ train_data : Union[str, Path, np.ndarray]
480
+ Training data.
481
+ data_type : Union[str, SupportedData]
482
+ Data type, see `SupportedData` for available options.
483
+ patch_size : List[int]
484
+ Patch size, 2D or 3D patch size.
485
+ axes : str
486
+ Axes of the data, chosen among SCZYX.
487
+ batch_size : int
488
+ Batch size.
489
+ val_data : Optional[Union[str, Path]], optional
490
+ Validation data, by default None.
491
+ transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
492
+ List of transforms to apply to training patches. If None, default transforms
493
+ are applied.
494
+ train_target_data : Optional[Union[str, Path]], optional
495
+ Training target data, by default None.
496
+ val_target_data : Optional[Union[str, Path]], optional
497
+ Validation target data, by default None.
498
+ read_source_func : Optional[Callable], optional
499
+ Function to read the source data, used if `data_type` is `custom`, by
500
+ default None.
501
+ extension_filter : str, optional
502
+ Filter for file extensions, used if `data_type` is `custom`, by default "".
503
+ val_percentage : float, optional
504
+ Percentage of the training data to use for validation if no validation data
505
+ is given, by default 0.1.
506
+ val_minimum_patches : int, optional
507
+ Minimum number of patches to split from the training data for validation if
508
+ no validation data is given, by default 5.
509
+ dataloader_params : dict, optional
510
+ Pytorch dataloader parameters, by default {}.
511
+ use_in_memory : bool, optional
512
+ Use in memory dataset if possible, by default True.
513
+ use_n2v2 : bool, optional
514
+ Use N2V2 transformation during training, by default False.
515
+ struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
516
+ Axis for the structN2V mask; no mask is applied if set to "none", by
517
+ default "none".
518
+ struct_n2v_span : int, optional
519
+ Span for the structN2V mask, by default 5.
520
+
521
+ Examples
522
+ --------
523
+ Create a TrainingDataWrapper with default transforms with a numpy array:
524
+ >>> import numpy as np
525
+ >>> from careamics import TrainingDataWrapper
526
+ >>> my_array = np.arange(256).reshape(16, 16)
527
+ >>> data_module = TrainingDataWrapper(
528
+ ... train_data=my_array,
529
+ ... data_type="array",
530
+ ... patch_size=(8, 8),
531
+ ... axes='YX',
532
+ ... batch_size=2,
533
+ ... )
534
+
535
+ For custom data types (those not supported by CAREamics), one can pass a read
536
+ function and a filter for the file extension:
537
+ >>> import numpy as np
538
+ >>> from careamics import TrainingDataWrapper
539
+ >>>
540
+ >>> def read_npy(path):
541
+ ... return np.load(path)
542
+ >>>
543
+ >>> data_module = TrainingDataWrapper(
544
+ ... train_data="path/to/data",
545
+ ... data_type="custom",
546
+ ... patch_size=(8, 8),
547
+ ... axes='YX',
548
+ ... batch_size=2,
549
+ ... read_source_func=read_npy,
550
+ ... extension_filter="*.npy",
551
+ ... )
552
+
553
+ If you want to use a different set of transformations, you can pass a list of
554
+ transforms:
555
+ >>> import numpy as np
556
+ >>> from careamics import TrainingDataWrapper
557
+ >>> from careamics.config.support import SupportedTransform
558
+ >>> my_array = np.arange(256).reshape(16, 16)
559
+ >>> my_transforms = [
560
+ ... {
561
+ ... "name": SupportedTransform.NORMALIZE.value,
562
+ ... "mean": 0,
563
+ ... "std": 1,
564
+ ... },
565
+ ... {
566
+ ... "name": SupportedTransform.N2V_MANIPULATE.value,
567
+ ... }
568
+ ... ]
569
+ >>> data_module = TrainingDataWrapper(
570
+ ... train_data=my_array,
571
+ ... data_type="array",
572
+ ... patch_size=(8, 8),
573
+ ... axes='YX',
574
+ ... batch_size=2,
575
+ ... transforms=my_transforms,
576
+ ... )
577
+ """
578
+
579
+ def __init__(
580
+ self,
581
+ train_data: Union[str, Path, np.ndarray],
582
+ data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
583
+ patch_size: List[int],
584
+ axes: str,
585
+ batch_size: int,
586
+ val_data: Optional[Union[str, Path]] = None,
587
+ transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
588
+ train_target_data: Optional[Union[str, Path]] = None,
589
+ val_target_data: Optional[Union[str, Path]] = None,
590
+ read_source_func: Optional[Callable] = None,
591
+ extension_filter: str = "",
592
+ val_percentage: float = 0.1,
593
+ val_minimum_patches: int = 5,
594
+ dataloader_params: Optional[dict] = None,
595
+ use_in_memory: bool = True,
596
+ use_n2v2: bool = False,
597
+ struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
598
+ struct_n2v_span: int = 5,
599
+ ) -> None:
600
+ """
601
+ LightningDataModule wrapper for training and validation datasets.
602
+
603
+ Since the Lightning datamodule has no access to the model, make sure that the
604
+ parameters passed to the datamodule are consistent with the model's
605
+ requirements.
606
+
607
+ The data module can be used with Path, str or numpy arrays. In the case of
608
+ numpy arrays, it loads and computes all the patches in memory. For Path and str
609
+ inputs, it calculates the total file size and estimates whether it can fit in
610
+ memory. If it does not, it iterates through the files. This behaviour can be
611
+ deactivated by setting `use_in_memory` to False, in which case it will
612
+ always use the iterable dataset to train on a Path or str.
613
+
614
+ To use array data, set `data_type` to `array` and pass a numpy array to
615
+ `train_data`.
616
+
617
+ Note that N2V requires a specific transformation (N2V manipulation), which
618
+ is not compatible with supervised training. The default transformations applied
619
+ to the training patches are defined in `careamics.config.data_model`. To use
620
+ different transformations, pass a list of transforms or an Albumentations
621
+ `Compose` as the `transforms` parameter. See examples for more details.
622
+
623
+ By default, CAREamics only supports types defined in
624
+ `careamics.config.support.SupportedData`. To read custom data types, you can set
625
+ `data_type` to `custom` and provide a function that returns a numpy array from a
626
+ path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
627
+ (e.g. "*.jpeg") to filter the files extension using `extension_filter`.
628
+
629
+ In the absence of validation data, the validation data is extracted from the
630
+ training data. The percentage of the training data to use for validation, as
631
+ well as the minimum number of patches to split from the training data for
632
+ validation can be set using `val_percentage` and `val_minimum_patches`,
633
+ respectively.
634
+
635
+ In `dataloader_params`, you can pass any parameter accepted by PyTorch
636
+ dataloaders, except for `batch_size`, which is set by the `batch_size`
637
+ parameter.
638
+
639
+ Finally, if you intend to use the N2V family of algorithms, you can set `use_n2v2`
640
+ to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to
641
+ define the axis and span of the structN2V mask. These parameters have no
642
+ effect if `train_target_data` or `transforms` are provided.
643
+
644
+ Parameters
645
+ ----------
646
+ train_data : Union[str, Path, np.ndarray]
647
+ Training data.
648
+ data_type : Union[str, SupportedData]
649
+ Data type, see `SupportedData` for available options.
650
+ patch_size : List[int]
651
+ Patch size, 2D or 3D patch size.
652
+ axes : str
653
+ Axes of the data, chosen among SCZYX.
654
+ batch_size : int
655
+ Batch size.
656
+ val_data : Optional[Union[str, Path]], optional
657
+ Validation data, by default None.
658
+ transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
659
+ List of transforms to apply to training patches. If None, default transforms
660
+ are applied.
661
+ train_target_data : Optional[Union[str, Path]], optional
662
+ Training target data, by default None.
663
+ val_target_data : Optional[Union[str, Path]], optional
664
+ Validation target data, by default None.
665
+ read_source_func : Optional[Callable], optional
666
+ Function to read the source data, used if `data_type` is `custom`, by
667
+ default None.
668
+ extension_filter : str, optional
669
+ Filter for file extensions, used if `data_type` is `custom`, by default "".
670
+ val_percentage : float, optional
671
+ Percentage of the training data to use for validation if no validation data
672
+ is given, by default 0.1.
673
+ val_minimum_patches : int, optional
674
+ Minimum number of patches to split from the training data for validation if
675
+ no validation data is given, by default 5.
676
+ dataloader_params : dict, optional
677
+ Pytorch dataloader parameters, by default {}.
678
+ use_in_memory : bool, optional
679
+ Use in memory dataset if possible, by default True.
680
+ use_n2v2 : bool, optional
681
+ Use N2V2 transformation during training, by default False.
682
+ struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
683
+ Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
684
+ default "none".
685
+ struct_n2v_span : int, optional
686
+ Span for the structN2V mask, by default 5.
687
+
688
+ Raises
689
+ ------
690
+ ValueError
691
+ If a target is set and N2V manipulation is present in the transforms.
692
+ """
693
+ if dataloader_params is None:
694
+ dataloader_params = {}
695
+ data_dict: Dict[str, Any] = {
696
+ "mode": "train",
697
+ "data_type": data_type,
698
+ "patch_size": patch_size,
699
+ "axes": axes,
700
+ "batch_size": batch_size,
701
+ "dataloader_params": dataloader_params,
702
+ }
703
+
704
+ # if transforms are passed (otherwise it will use the default ones)
705
+ if transforms is not None:
706
+ data_dict["transforms"] = transforms
707
+
708
+ # validate configuration
709
+ self.data_config = DataConfig(**data_dict)
710
+
711
+ # N2V specific checks, N2V, structN2V, and transforms
712
+ if (
713
+ self.data_config.has_transform_list()
714
+ and self.data_config.has_n2v_manipulate()
715
+ ):
716
+ # there is no target, so N2V2 and structN2V settings can be changed
717
+ if train_target_data is None:
718
+ self.data_config.set_N2V2(use_n2v2)
719
+ self.data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
720
+ else:
721
+ raise ValueError(
722
+ "Cannot have both supervised training (target data) and "
723
+ "N2V manipulation in the transforms. Pass a list of transforms "
724
+ "that is compatible with your supervised training."
725
+ )
726
+
727
+ # sanity check on the dataloader parameters
728
+ if "batch_size" in dataloader_params:
729
+ # remove it
730
+ del dataloader_params["batch_size"]
731
+
732
+ super().__init__(
733
+ data_config=self.data_config,
734
+ train_data=train_data,
735
+ val_data=val_data,
736
+ train_data_target=train_target_data,
737
+ val_data_target=val_target_data,
738
+ read_source_func=read_source_func,
739
+ extension_filter=extension_filter,
740
+ val_percentage=val_percentage,
741
+ val_minimum_split=val_minimum_patches,
742
+ use_in_memory=use_in_memory,
743
+ )
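
For orientation, below is a minimal, hedged sketch of how the `TrainingDataWrapper` added in this file might be exercised on its own. It relies only on the constructor parameters and LightningDataModule hooks visible in the diff above; the default transforms are assumed to work out of the box for array data, and the CAREamics LightningModule is not shown in this diff, so the actual training step is only indicated in a comment.

import numpy as np

from careamics import TrainingDataWrapper  # import path taken from the docstring examples above

# Toy 2D training data; 8x8 patches are extracted from it.
train_array = np.random.rand(64, 64).astype(np.float32)

data_module = TrainingDataWrapper(
    train_data=train_array,
    data_type="array",       # "array", "tiff" or "custom"
    patch_size=(8, 8),
    axes="YX",
    batch_size=4,
    val_percentage=0.1,      # validation split taken from the training patches
    val_minimum_patches=5,
)

# Exercise the LightningDataModule hooks directly.
data_module.prepare_data()
data_module.setup()
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

# In a real run, the datamodule would instead be passed to
# pytorch_lightning.Trainer.fit(model, datamodule=data_module), where `model`
# is the CAREamics LightningModule configured elsewhere (not shown in this diff).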