careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (133)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/lightning_datamodule.py (new file)
@@ -0,0 +1,665 @@
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
+
+ import numpy as np
+ import pytorch_lightning as L
+ from albumentations import Compose
+ from torch.utils.data import DataLoader
+
+ from careamics.config import DataModel
+ from careamics.config.data_model import TRANSFORMS_UNION
+ from careamics.config.support import SupportedData
+ from careamics.dataset.dataset_utils import (
+     get_files_size,
+     get_read_func,
+     list_files,
+     validate_source_target_files,
+ )
+ from careamics.dataset.in_memory_dataset import (
+     InMemoryDataset,
+ )
+ from careamics.dataset.iterable_dataset import (
+     PathIterableDataset,
+ )
+ from careamics.utils import get_logger, get_ram_size
+
+ DatasetType = Union[InMemoryDataset, PathIterableDataset]
+
+ logger = get_logger(__name__)
+
+
+ class CAREamicsWood(L.LightningDataModule):
+     """
+     LightningDataModule for training and validation datasets.
+
+     The data module can be used with Path, str or numpy arrays. In the case of
+     numpy arrays, it loads and computes all the patches in memory. For Path and str
+     inputs, it calculates the total file size and estimates whether it can fit in
+     memory. If it does not, it iterates through the files. This behaviour can be
+     deactivated by setting `use_in_memory` to False, in which case it will
+     always use the iterable dataset to train on a Path or str.
+
+     The data can be either a folder containing images or a single file.
+
+     Validation can be omitted, in which case the validation data is extracted from
+     the training data. The percentage of the training data to use for validation,
+     as well as the minimum number of patches or files to split from the training
+     data, can be set using `val_percentage` and `val_minimum_split`, respectively.
+
+     To read custom data types, you can set `data_type` to `custom` in `data_config`
+     and provide a function that returns a numpy array from a path as the
+     `read_source_func` parameter. The function will receive a Path object and
+     an axes string as arguments, the axes being derived from the `data_config`.
+
+     You can also provide a `fnmatch`- and `Path.rglob`-compatible expression (e.g.
+     "*.czi") to filter the file extensions using `extension_filter`.
+     """
+
+     def __init__(
+         self,
+         data_config: DataModel,
+         train_data: Union[Path, str, np.ndarray],
+         val_data: Optional[Union[Path, str, np.ndarray]] = None,
+         train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
+         val_data_target: Optional[Union[Path, str, np.ndarray]] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         data_config : DataModel
+             Pydantic model for CAREamics data configuration.
+         train_data : Union[Path, str, np.ndarray]
+             Training data, can be a path to a folder, a file or a numpy array.
+         val_data : Optional[Union[Path, str, np.ndarray]], optional
+             Validation data, can be a path to a folder, a file or a numpy array, by
+             default None.
+         train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+             Training target data, can be a path to a folder, a file or a numpy array, by
+             default None.
+         val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+             Validation target data, can be a path to a folder, a file or a numpy array,
+             by default None.
+         read_source_func : Optional[Callable], optional
+             Function to read the source data, by default None. Only used for `custom`
+             data type (see DataModel).
+         extension_filter : str, optional
+             Filter for file extensions, by default "". Only used for `custom` data types
+             (see DataModel).
+         val_percentage : float, optional
+             Percentage of the training data to use for validation, by default 0.1. Only
+             used if `val_data` is None.
+         val_minimum_split : int, optional
+             Minimum number of patches or files to split from the training data for
+             validation, by default 5. Only used if `val_data` is None.
+
+         Raises
+         ------
+         NotImplementedError
+             Raised if target data is provided.
+         ValueError
+             If the input types are mixed (e.g. Path and np.ndarray).
+         ValueError
+             If the data type is `custom` and no `read_source_func` is provided.
+         ValueError
+             If the data type is `array` and the input is not a numpy array.
+         ValueError
+             If the data type is `tiff` and the input is neither a Path nor a str.
+         """
+         super().__init__()
+
+         # check input types coherence (no mixed types)
+         inputs = [train_data, val_data, train_data_target, val_data_target]
+         types_set = {type(i) for i in inputs}
+         if len(types_set) > 2:  # None + expected type
+             raise ValueError(
+                 f"Inputs for `train_data`, `val_data`, `train_data_target` and "
+                 f"`val_data_target` must be of the same type or None. Got "
+                 f"{types_set}."
+             )
+
+         # check that a read source function is provided for custom types
+         if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
+             raise ValueError(
+                 f"Data type {SupportedData.CUSTOM} is not allowed without "
+                 f"specifying a `read_source_func`."
+             )
+
+         # and that arrays are passed, if array type specified
+         elif data_config.data_type == SupportedData.ARRAY and not isinstance(
+             train_data, np.ndarray
+         ):
+             raise ValueError(
+                 f"Expected array input (see configuration.data.data_type), but got "
+                 f"{type(train_data)} instead."
+             )
+
+         # and that Path or str are passed, if tiff file type specified
+         elif data_config.data_type == SupportedData.TIFF and (
+             not isinstance(train_data, Path) and not isinstance(train_data, str)
+         ):
+             raise ValueError(
+                 f"Expected Path or str input (see configuration.data.data_type), "
+                 f"but got {type(train_data)} instead."
+             )
+
+         # configuration
+         self.data_config = data_config
+         self.data_type = data_config.data_type
+         self.batch_size = data_config.batch_size
+         self.use_in_memory = use_in_memory
+
+         # data
+         self.train_data = train_data
+         self.val_data = val_data
+
+         self.train_data_target = train_data_target
+         self.val_data_target = val_data_target
+         self.val_percentage = val_percentage
+         self.val_minimum_split = val_minimum_split
+
+         # read source function corresponding to the requested type
+         if data_config.data_type == SupportedData.CUSTOM:
+             # mypy check
+             assert read_source_func is not None
+
+             self.read_source_func: Callable = read_source_func
+
+         elif data_config.data_type != SupportedData.ARRAY:
+             self.read_source_func = get_read_func(data_config.data_type)
+
+         self.extension_filter = extension_filter
+
+         # Pytorch dataloader parameters
+         self.dataloader_params = (
+             data_config.dataloader_params if data_config.dataloader_params else {}
+         )
+
+     def prepare_data(self) -> None:
+         """
+         Hook used to prepare the data before calling `setup`.
+
+         Here, we only need to examine the data if it was provided as a str or a Path.
+
+         TODO: from lightning doc:
+         prepare_data is called from the main process. It is not recommended to assign
+         state here (e.g. self.x = y) since it is called on a single process and if you
+         assign states here then they won't be available for other processes.
+
+         https://lightning.ai/docs/pytorch/stable/data/datamodule.html
+         """
+         # if the data is a Path or a str
+         if (
+             not isinstance(self.train_data, np.ndarray)
+             and not isinstance(self.val_data, np.ndarray)
+             and not isinstance(self.train_data_target, np.ndarray)
+             and not isinstance(self.val_data_target, np.ndarray)
+         ):
+             # list training files
+             self.train_files = list_files(
+                 self.train_data, self.data_type, self.extension_filter
+             )
+             self.train_files_size = get_files_size(self.train_files)
+
+             # list validation files
+             if self.val_data is not None:
+                 self.val_files = list_files(
+                     self.val_data, self.data_type, self.extension_filter
+                 )
+
+             # same for target data
+             if self.train_data_target is not None:
+                 self.train_target_files: List[Path] = list_files(
+                     self.train_data_target, self.data_type, self.extension_filter
+                 )
+
+                 # verify that they match the training data
+                 validate_source_target_files(self.train_files, self.train_target_files)
+
+             if self.val_data_target is not None:
+                 self.val_target_files = list_files(
+                     self.val_data_target, self.data_type, self.extension_filter
+                 )
+
+                 # verify that they match the validation data
+                 validate_source_target_files(self.val_files, self.val_target_files)
+
+     def setup(self, *args: Any, **kwargs: Any) -> None:
+         """Hook called at the beginning of fit, validate, or predict."""
+         # if numpy array
+         if self.data_type == SupportedData.ARRAY:
+             # train dataset
+             self.train_dataset: DatasetType = InMemoryDataset(
+                 data_config=self.data_config,
+                 inputs=self.train_data,
+                 data_target=self.train_data_target,
+             )
+
+             # validation dataset
+             if self.val_data is not None:
+                 # create its own dataset
+                 self.val_dataset: DatasetType = InMemoryDataset(
+                     data_config=self.data_config,
+                     inputs=self.val_data,
+                     data_target=self.val_data_target,
+                 )
+             else:
+                 # extract validation from the training patches
+                 self.val_dataset = self.train_dataset.split_dataset(
+                     percentage=self.val_percentage,
+                     minimum_patches=self.val_minimum_split,
+                 )
+
+         # else we read files
+         else:
+             # Heuristic: if the file size is smaller than 80% of the RAM,
+             # we run the training in memory, otherwise we switch to the iterable
+             # dataset. The switch is deactivated if use_in_memory is False.
+             if self.use_in_memory and self.train_files_size < get_ram_size() * 0.8:
+                 # train dataset
+                 self.train_dataset = InMemoryDataset(
+                     data_config=self.data_config,
+                     inputs=self.train_files,
+                     data_target=self.train_target_files
+                     if self.train_data_target
+                     else None,
+                     read_source_func=self.read_source_func,
+                 )
+
+                 # validation dataset
+                 if self.val_data is not None:
+                     self.val_dataset = InMemoryDataset(
+                         data_config=self.data_config,
+                         inputs=self.val_files,
+                         data_target=self.val_target_files
+                         if self.val_data_target
+                         else None,
+                         read_source_func=self.read_source_func,
+                     )
+                 else:
+                     # split dataset
+                     self.val_dataset = self.train_dataset.split_dataset(
+                         percentage=self.val_percentage,
+                         minimum_patches=self.val_minimum_split,
+                     )
+
+             # else if the data is too large, load file by file during training
+             else:
+                 # create training dataset
+                 self.train_dataset = PathIterableDataset(
+                     data_config=self.data_config,
+                     src_files=self.train_files,
+                     target_files=self.train_target_files
+                     if self.train_data_target
+                     else None,
+                     read_source_func=self.read_source_func,
+                 )
+
+                 # create validation dataset
+                 if self.val_files is not None:
+                     # create its own dataset
+                     self.val_dataset = PathIterableDataset(
+                         data_config=self.data_config,
+                         src_files=self.val_files,
+                         target_files=self.val_target_files
+                         if self.val_data_target
+                         else None,
+                         read_source_func=self.read_source_func,
+                     )
+                 elif len(self.train_files) <= self.val_minimum_split:
+                     raise ValueError(
+                         f"Not enough files to split a minimum of "
+                         f"{self.val_minimum_split} files, got {len(self.train_files)} "
+                         f"files."
+                     )
+                 else:
+                     # extract validation from the training patches
+                     self.val_dataset = self.train_dataset.split_dataset(
+                         percentage=self.val_percentage,
+                         minimum_files=self.val_minimum_split,
+                     )
+
+     def train_dataloader(self) -> Any:
+         """
+         Create a dataloader for training.
+
+         Returns
+         -------
+         Any
+             Training dataloader.
+         """
+         return DataLoader(
+             self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
+         )
+
+     def val_dataloader(self) -> Any:
+         """
+         Create a dataloader for validation.
+
+         Returns
+         -------
+         Any
+             Validation dataloader.
+         """
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+         )
+
+
+ class CAREamicsTrainDataModule(CAREamicsWood):
+     """
+     LightningDataModule wrapper for training and validation datasets.
+
+     Since the lightning datamodule has no access to the model, make sure that the
+     parameters passed to the datamodule are consistent with the model's requirements
+     and are coherent.
+
+     The data module can be used with Path, str or numpy arrays. In the case of
+     numpy arrays, it loads and computes all the patches in memory. For Path and str
+     inputs, it calculates the total file size and estimates whether it can fit in
+     memory. If it does not, it iterates through the files. This behaviour can be
+     deactivated by setting `use_in_memory` to False, in which case it will
+     always use the iterable dataset to train on a Path or str.
+
+     To use array data, set `data_type` to `array` and pass a numpy array to
+     `train_data`.
+
+     In particular, N2V requires a specific transformation (N2V manipulate), which is
+     not compatible with supervised training. The default transformations applied to
+     the training patches are defined in `careamics.config.data_model`. To use
+     different transformations, pass a list of transforms or an albumentations
+     `Compose` as the `transforms` parameter. See examples for more details.
+
+     By default, CAREamics only supports types defined in
+     `careamics.config.support.SupportedData`. To read custom data types, you can set
+     `data_type` to `custom` and provide a function that returns a numpy array from a
+     path. Additionally, pass a `fnmatch`- and `Path.rglob`-compatible expression
+     (e.g. "*.jpeg") to filter the file extensions using `extension_filter`.
+
+     In the absence of validation data, the validation data is extracted from the
+     training data. The percentage of the training data to use for validation, as well
+     as the minimum number of patches to split from the training data for validation,
+     can be set using `val_percentage` and `val_minimum_patches`, respectively.
+
+     In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders,
+     except for `batch_size`, which is set by the `batch_size` parameter.
+
+     Finally, if you intend to use the N2V family of algorithms, you can set `use_n2v2`
+     to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to
+     define the axis and span of the structN2V mask. These parameters have no effect
+     if `train_target_data` or `transforms` are provided.
+
+     Parameters
+     ----------
+     train_data : Union[str, Path, np.ndarray]
+         Training data.
+     data_type : Union[str, SupportedData]
+         Data type, see `SupportedData` for available options.
+     patch_size : List[int]
+         Patch size, 2D or 3D patch size.
+     axes : str
+         Axes of the data, chosen among SCZYX.
+     batch_size : int
+         Batch size.
+     val_data : Optional[Union[str, Path]], optional
+         Validation data, by default None.
+     transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+         List of transforms to apply to training patches. If None, default transforms
+         are applied.
+     train_target_data : Optional[Union[str, Path]], optional
+         Training target data, by default None.
+     val_target_data : Optional[Union[str, Path]], optional
+         Validation target data, by default None.
+     read_source_func : Optional[Callable], optional
+         Function to read the source data, used if `data_type` is `custom`, by
+         default None.
+     extension_filter : str, optional
+         Filter for file extensions, used if `data_type` is `custom`, by default "".
+     val_percentage : float, optional
+         Percentage of the training data to use for validation if no validation data
+         is given, by default 0.1.
+     val_minimum_patches : int, optional
+         Minimum number of patches to split from the training data for validation if
+         no validation data is given, by default 5.
+     dataloader_params : dict, optional
+         PyTorch dataloader parameters, by default {}.
+     use_in_memory : bool, optional
+         Use the in-memory dataset if possible, by default True.
+     use_n2v2 : bool, optional
+         Use N2V2 transformation during training, by default False.
+     struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+         Axis for the structN2V mask; no mask is applied if `struct_n2v_axis` is
+         `none`, by default "none".
+     struct_n2v_span : int, optional
+         Span for the structN2V mask, by default 5.
+
+     Examples
+     --------
+     Create a CAREamicsTrainDataModule with default transforms from a numpy array:
+     >>> import numpy as np
+     >>> from careamics import CAREamicsTrainDataModule
+     >>> my_array = np.arange(256).reshape(16, 16)
+     >>> data_module = CAREamicsTrainDataModule(
+     ...     train_data=my_array,
+     ...     data_type="array",
+     ...     patch_size=(8, 8),
+     ...     axes='YX',
+     ...     batch_size=2,
+     ... )
+
+     For custom data types (those not supported by CAREamics), one can pass a read
+     function and a filter for the file extensions:
+     >>> import numpy as np
+     >>> from careamics import CAREamicsTrainDataModule
+     >>>
+     >>> def read_npy(path):
+     ...     return np.load(path)
+     >>>
+     >>> data_module = CAREamicsTrainDataModule(
+     ...     train_data="path/to/data",
+     ...     data_type="custom",
+     ...     patch_size=(8, 8),
+     ...     axes='YX',
+     ...     batch_size=2,
+     ...     read_source_func=read_npy,
+     ...     extension_filter="*.npy",
+     ... )
+
+     If you want to use a different set of transformations, you can pass a list of
+     transforms:
+     >>> import numpy as np
+     >>> from careamics import CAREamicsTrainDataModule
+     >>> from careamics.config.support import SupportedTransform
+     >>> my_array = np.arange(256).reshape(16, 16)
+     >>> my_transforms = [
+     ...     {
+     ...         "name": SupportedTransform.NORMALIZE.value,
+     ...         "mean": 0,
+     ...         "std": 1,
+     ...     },
+     ...     {
+     ...         "name": SupportedTransform.N2V_MANIPULATE.value,
+     ...     }
+     ... ]
+     >>> data_module = CAREamicsTrainDataModule(
+     ...     train_data=my_array,
+     ...     data_type="array",
+     ...     patch_size=(8, 8),
+     ...     axes='YX',
+     ...     batch_size=2,
+     ...     transforms=my_transforms,
+     ... )
+     """
+
+     def __init__(
+         self,
+         train_data: Union[str, Path, np.ndarray],
+         data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+         patch_size: List[int],
+         axes: str,
+         batch_size: int,
+         val_data: Optional[Union[str, Path]] = None,
+         transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
+         train_target_data: Optional[Union[str, Path]] = None,
+         val_target_data: Optional[Union[str, Path]] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         val_percentage: float = 0.1,
+         val_minimum_patches: int = 5,
+         dataloader_params: Optional[dict] = None,
+         use_in_memory: bool = True,
+         use_n2v2: bool = False,
+         struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+         struct_n2v_span: int = 5,
+     ) -> None:
+         """
+         LightningDataModule wrapper for training and validation datasets.
+
+         Since the lightning datamodule has no access to the model, make sure that the
+         parameters passed to the datamodule are consistent with the model's requirements
+         and are coherent.
+
+         The data module can be used with Path, str or numpy arrays. In the case of
+         numpy arrays, it loads and computes all the patches in memory. For Path and str
+         inputs, it calculates the total file size and estimates whether it can fit in
+         memory. If it does not, it iterates through the files. This behaviour can be
+         deactivated by setting `use_in_memory` to False, in which case it will
+         always use the iterable dataset to train on a Path or str.
+
+         To use array data, set `data_type` to `array` and pass a numpy array to
+         `train_data`.
+
+         In particular, N2V requires a specific transformation (N2V manipulate), which
+         is not compatible with supervised training. The default transformations applied
+         to the training patches are defined in `careamics.config.data_model`. To use
+         different transformations, pass a list of transforms or an albumentations
+         `Compose` as the `transforms` parameter. See examples for more details.
+
+         By default, CAREamics only supports types defined in
+         `careamics.config.support.SupportedData`. To read custom data types, you can set
+         `data_type` to `custom` and provide a function that returns a numpy array from a
+         path. Additionally, pass a `fnmatch`- and `Path.rglob`-compatible expression
+         (e.g. "*.jpeg") to filter the file extensions using `extension_filter`.
+
+         In the absence of validation data, the validation data is extracted from the
+         training data. The percentage of the training data to use for validation, as
+         well as the minimum number of patches to split from the training data for
+         validation, can be set using `val_percentage` and `val_minimum_patches`,
+         respectively.
+
+         In `dataloader_params`, you can pass any parameter accepted by PyTorch
+         dataloaders, except for `batch_size`, which is set by the `batch_size`
+         parameter.
+
+         Finally, if you intend to use the N2V family of algorithms, you can set
+         `use_n2v2` to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span`
+         parameters to define the axis and span of the structN2V mask. These parameters
+         have no effect if `train_target_data` or `transforms` are provided.
+
+         Parameters
+         ----------
+         train_data : Union[str, Path, np.ndarray]
+             Training data.
+         data_type : Union[str, SupportedData]
+             Data type, see `SupportedData` for available options.
+         patch_size : List[int]
+             Patch size, 2D or 3D patch size.
+         axes : str
+             Axes of the data, chosen among SCZYX.
+         batch_size : int
+             Batch size.
+         val_data : Optional[Union[str, Path]], optional
+             Validation data, by default None.
+         transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+             List of transforms to apply to training patches. If None, default transforms
+             are applied.
+         train_target_data : Optional[Union[str, Path]], optional
+             Training target data, by default None.
+         val_target_data : Optional[Union[str, Path]], optional
+             Validation target data, by default None.
+         read_source_func : Optional[Callable], optional
+             Function to read the source data, used if `data_type` is `custom`, by
+             default None.
+         extension_filter : str, optional
+             Filter for file extensions, used if `data_type` is `custom`, by default "".
+         val_percentage : float, optional
+             Percentage of the training data to use for validation if no validation data
+             is given, by default 0.1.
+         val_minimum_patches : int, optional
+             Minimum number of patches to split from the training data for validation if
+             no validation data is given, by default 5.
+         dataloader_params : dict, optional
+             PyTorch dataloader parameters, by default {}.
+         use_in_memory : bool, optional
+             Use the in-memory dataset if possible, by default True.
+         use_n2v2 : bool, optional
+             Use N2V2 transformation during training, by default False.
+         struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+             Axis for the structN2V mask; no mask is applied if `struct_n2v_axis` is
+             `none`, by default "none".
+         struct_n2v_span : int, optional
+             Span for the structN2V mask, by default 5.
+
+         Raises
+         ------
+         ValueError
+             If a target is set and N2V manipulation is present in the transforms.
+         """
+         if dataloader_params is None:
+             dataloader_params = {}
+         data_dict: Dict[str, Any] = {
+             "mode": "train",
+             "data_type": data_type,
+             "patch_size": patch_size,
+             "axes": axes,
+             "batch_size": batch_size,
+             "dataloader_params": dataloader_params,
+         }
+
+         # if transforms are passed (otherwise it will use the default ones)
+         if transforms is not None:
+             data_dict["transforms"] = transforms
+
+         # validate configuration
+         self.data_config = DataModel(**data_dict)
+
+         # N2V-specific checks: N2V, structN2V, and transforms
+         if (
+             self.data_config.has_transform_list()
+             and self.data_config.has_n2v_manipulate()
+         ):
+             # there is no target, so n2v2 and structN2V can be changed
+             if train_target_data is None:
+                 self.data_config.set_N2V2(use_n2v2)
+                 self.data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+             else:
+                 raise ValueError(
+                     "Cannot have both supervised training (target data) and "
+                     "N2V manipulation in the transforms. Pass a list of transforms "
+                     "that is compatible with your supervised training."
+                 )
+
+         # sanity check on the dataloader parameters
+         if "batch_size" in dataloader_params:
+             # remove it
+             del dataloader_params["batch_size"]
+
+         super().__init__(
+             data_config=self.data_config,
+             train_data=train_data,
+             val_data=val_data,
+             train_data_target=train_target_data,
+             val_data_target=val_target_data,
+             read_source_func=read_source_func,
+             extension_filter=extension_filter,
+             val_percentage=val_percentage,
+             val_minimum_split=val_minimum_patches,
+             use_in_memory=use_in_memory,
+         )
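
For reference, a minimal sketch of how the new data module added in this release could be driven on its own, using only the classes and hooks introduced in the hunk above (`CAREamicsTrainDataModule`, `prepare_data`, `setup`, `train_dataloader`). The array shape and the manual hook calls are illustrative assumptions; in normal use the datamodule would be passed to a PyTorch Lightning `Trainer` together with a model, and the exact batch contents depend on the configured transforms.

    import numpy as np

    from careamics import CAREamicsTrainDataModule

    # illustrative input: a small 2D image, mirroring the docstring examples
    train_array = np.random.rand(64, 64).astype(np.float32)

    data_module = CAREamicsTrainDataModule(
        train_data=train_array,
        data_type="array",  # patches are computed in memory for array inputs
        patch_size=[8, 8],
        axes="YX",
        batch_size=2,
    )

    # outside of a Trainer, the Lightning hooks have to be called manually
    data_module.prepare_data()
    data_module.setup()

    # one training batch produced by the underlying InMemoryDataset
    batch = next(iter(data_module.train_dataloader()))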