careamics-0.1.0rc6-py3-none-any.whl → careamics-0.1.0rc8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics has been flagged as possibly problematic by the registry.
Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
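
The reorganization above moves the Lightning entry points from top-level modules into the careamics.lightning subpackage and replaces the prediction/ package with prediction_utils/. A minimal migration sketch is shown below; it assumes TrainDataModule and create_train_datamodule are re-exported from careamics.lightning, which the docstring examples in the hunks further down suggest, though only create_train_datamodule appears there explicitly.

# Hypothetical migration sketch for the 0.1.0rc6 -> 0.1.0rc8 rename
# (verify the exported names against the actual release).
#
# 0.1.0rc6:
#   from careamics import TrainingDataWrapper
#
# 0.1.0rc8:
from careamics.lightning import TrainDataModule, create_train_datamodule

The detailed hunks that follow cover the training and validation data module, where this rename happens.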
@@ -1,10 +1,11 @@
 """Training and validation Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import pytorch_lightning as L
+from numpy.typing import NDArray
 from torch.utils.data import DataLoader
 
 from careamics.config import DataConfig
@@ -12,7 +13,6 @@ from careamics.config.data_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedData
 from careamics.dataset.dataset_utils import (
     get_files_size,
-    get_read_func,
     list_files,
     validate_source_target_files,
 )
@@ -22,6 +22,7 @@ from careamics.dataset.in_memory_dataset import (
 from careamics.dataset.iterable_dataset import (
     PathIterableDataset,
 )
+from careamics.file_io.read import get_read_func
 from careamics.utils import get_logger, get_ram_size
 
 DatasetType = Union[InMemoryDataset, PathIterableDataset]
@@ -29,7 +30,7 @@ DatasetType = Union[InMemoryDataset, PathIterableDataset]
 logger = get_logger(__name__)
 
 
-class CAREamicsTrainData(L.LightningDataModule):
+class TrainDataModule(L.LightningDataModule):
     """
     CAREamics Ligthning training and validation data module.
 
@@ -59,18 +60,18 @@ class CAREamicsTrainData(L.LightningDataModule):
     ----------
     data_config : DataModel
         Pydantic model for CAREamics data configuration.
-    train_data : Union[Path, str, np.ndarray]
+    train_data : pathlib.Path or str or numpy.ndarray
         Training data, can be a path to a folder, a file or a numpy array.
-    val_data : Optional[Union[Path, str, np.ndarray]], optional
+    val_data : pathlib.Path or str or numpy.ndarray, optional
         Validation data, can be a path to a folder, a file or a numpy array, by
         default None.
-    train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+    train_data_target : pathlib.Path or str or numpy.ndarray, optional
         Training target data, can be a path to a folder, a file or a numpy array, by
         default None.
-    val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+    val_data_target : pathlib.Path or str or numpy.ndarray, optional
         Validation target data, can be a path to a folder, a file or a numpy array,
         by default None.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
         Function to read the source data, by default None. Only used for `custom`
         data type (see DataModel).
     extension_filter : str, optional
@@ -95,13 +96,13 @@ class CAREamicsTrainData(L.LightningDataModule):
         Batch size.
     use_in_memory : bool
         Whether to use in memory dataset if possible.
-    train_data : Union[Path, np.ndarray]
+    train_data : pathlib.Path or numpy.ndarray
         Training data.
-    val_data : Optional[Union[Path, np.ndarray]]
+    val_data : pathlib.Path or numpy.ndarray
         Validation data.
-    train_data_target : Optional[Union[Path, np.ndarray]]
+    train_data_target : pathlib.Path or numpy.ndarray
         Training target data.
-    val_data_target : Optional[Union[Path, np.ndarray]]
+    val_data_target : pathlib.Path or numpy.ndarray
         Validation target data.
     val_percentage : float
         Percentage of the training data to use for validation, if no validation data is
@@ -118,10 +119,10 @@ class CAREamicsTrainData(L.LightningDataModule):
     def __init__(
         self,
         data_config: DataConfig,
-        train_data: Union[Path, str, np.ndarray],
-        val_data: Optional[Union[Path, str, np.ndarray]] = None,
-        train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
-        val_data_target: Optional[Union[Path, str, np.ndarray]] = None,
+        train_data: Union[Path, str, NDArray],
+        val_data: Optional[Union[Path, str, NDArray]] = None,
+        train_data_target: Optional[Union[Path, str, NDArray]] = None,
+        val_data_target: Optional[Union[Path, str, NDArray]] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         val_percentage: float = 0.1,
@@ -135,18 +136,18 @@ class CAREamicsTrainData(L.LightningDataModule):
        ----------
        data_config : DataModel
            Pydantic model for CAREamics data configuration.
-        train_data : Union[Path, str, np.ndarray]
+        train_data : pathlib.Path or str or numpy.ndarray
            Training data, can be a path to a folder, a file or a numpy array.
-        val_data : Optional[Union[Path, str, np.ndarray]], optional
+        val_data : pathlib.Path or str or numpy.ndarray, optional
            Validation data, can be a path to a folder, a file or a numpy array, by
            default None.
-        train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        train_data_target : pathlib.Path or str or numpy.ndarray, optional
            Training target data, can be a path to a folder, a file or a numpy array, by
            default None.
-        val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        val_data_target : pathlib.Path or str or numpy.ndarray, optional
            Validation target data, can be a path to a folder, a file or a numpy array,
            by default None.
-        read_source_func : Optional[Callable], optional
+        read_source_func : Callable, optional
            Function to read the source data, by default None. Only used for `custom`
            data type (see DataModel).
        extension_filter : str, optional
@@ -166,7 +167,7 @@ class CAREamicsTrainData(L.LightningDataModule):
        NotImplementedError
            Raised if target data is provided.
        ValueError
-            If the input types are mixed (e.g. Path and np.ndarray).
+            If the input types are mixed (e.g. Path and numpy.ndarray).
        ValueError
            If the data type is `custom` and no `read_source_func` is provided.
        ValueError
@@ -223,21 +224,21 @@ class CAREamicsTrainData(L.LightningDataModule):
        self.use_in_memory: bool = use_in_memory
 
        # data: make data Path or np.ndarray, use type annotations for mypy
-        self.train_data: Union[Path, np.ndarray] = (
+        self.train_data: Union[Path, NDArray] = (
            Path(train_data) if isinstance(train_data, str) else train_data
        )
 
-        self.val_data: Union[Path, np.ndarray] = (
+        self.val_data: Union[Path, NDArray] = (
            Path(val_data) if isinstance(val_data, str) else val_data
        )
 
-        self.train_data_target: Union[Path, np.ndarray] = (
+        self.train_data_target: Union[Path, NDArray] = (
            Path(train_data_target)
            if isinstance(train_data_target, str)
            else train_data_target
        )
 
-        self.val_data_target: Union[Path, np.ndarray] = (
+        self.val_data_target: Union[Path, NDArray] = (
            Path(val_data_target)
            if isinstance(val_data_target, str)
            else val_data_target
@@ -260,7 +261,7 @@ class CAREamicsTrainData(L.LightningDataModule):
        self.extension_filter: str = extension_filter
 
        # Pytorch dataloader parameters
-        self.dataloader_params: Dict[str, Any] = (
+        self.dataloader_params: dict[str, Any] = (
            data_config.dataloader_params if data_config.dataloader_params else {}
        )
 
@@ -298,7 +299,7 @@ class CAREamicsTrainData(L.LightningDataModule):
 
            # same for target data
            if self.train_data_target is not None:
-                self.train_target_files: List[Path] = list_files(
+                self.train_target_files: list[Path] = list_files(
                    self.train_data_target, self.data_type, self.extension_filter
                )
 
@@ -403,7 +404,7 @@ class CAREamicsTrainData(L.LightningDataModule):
                )
 
                # create validation dataset
-                if self.val_files is not None:
+                if self.val_data is not None:
                    # create its own dataset
                    self.val_dataset = PathIterableDataset(
                        data_config=self.data_config,
@@ -423,9 +424,19 @@ class CAREamicsTrainData(L.LightningDataModule):
                    # extract validation from the training patches
                    self.val_dataset = self.train_dataset.split_dataset(
                        percentage=self.val_percentage,
-                        minimum_files=self.val_minimum_split,
+                        minimum_number=self.val_minimum_split,
                    )
 
+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        Returns
+        -------
+        tuple of list
+            Means and standard deviations across channels of the training data.
+        """
+        return self.train_dataset.get_data_statistics()
+
     def train_dataloader(self) -> Any:
        """
        Create a dataloader for training.
@@ -454,12 +465,30 @@ class CAREamicsTrainData(L.LightningDataModule):
        )
 
 
-class TrainingDataWrapper(CAREamicsTrainData):
-    """
-    Wrapper around the CAREamics Lightning training data module.
-
-    This class is used to explicitely pass the parameters usually contained in a
-    `data_model` configuration.
+def create_train_datamodule(
+    train_data: Union[str, Path, NDArray],
+    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+    patch_size: list[int],
+    axes: str,
+    batch_size: int,
+    val_data: Optional[Union[str, Path, NDArray]] = None,
+    transforms: Optional[list[TRANSFORMS_UNION]] = None,
+    train_target_data: Optional[Union[str, Path, NDArray]] = None,
+    val_target_data: Optional[Union[str, Path, NDArray]] = None,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    val_percentage: float = 0.1,
+    val_minimum_patches: int = 5,
+    dataloader_params: Optional[dict] = None,
+    use_in_memory: bool = True,
+    use_n2v2: bool = False,
+    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+    struct_n2v_span: int = 5,
+) -> TrainDataModule:
+    """Create a TrainDataModule.
+
+    This function is used to explicitely pass the parameters usually contained in a
+    `data_model` configuration to a TrainDataModule.
 
    Since the lightning datamodule has no access to the model, make sure that the
    parameters passed to the datamodule are consistent with the model's requirements and
@@ -501,26 +530,26 @@ class TrainingDataWrapper(CAREamicsTrainData):
 
    Parameters
    ----------
-    train_data : Union[str, Path, np.ndarray]
+    train_data : pathlib.Path or str or numpy.ndarray
        Training data.
-    data_type : Union[str, SupportedData]
+    data_type : {"array", "tiff", "custom"}
        Data type, see `SupportedData` for available options.
-    patch_size : List[int]
+    patch_size : list of int
        Patch size, 2D or 3D patch size.
    axes : str
        Axes of the data, choosen amongst SCZYX.
    batch_size : int
        Batch size.
-    val_data : Optional[Union[str, Path]], optional
+    val_data : pathlib.Path or str or numpy.ndarray, optional
        Validation data, by default None.
-    transforms : List[TRANSFORMS_UNION], optional
+    transforms : list of Transforms, optional
        List of transforms to apply to training patches. If None, default transforms
        are applied.
-    train_target_data : Optional[Union[str, Path]], optional
+    train_target_data : pathlib.Path or str or numpy.ndarray, optional
        Training target data, by default None.
-    val_target_data : Optional[Union[str, Path]], optional
+    val_target_data : pathlib.Path or str or numpy.ndarray, optional
        Validation target data, by default None.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
        Function to read the source data, used if `data_type` is `custom`, by
        default None.
    extension_filter : str, optional
@@ -537,19 +566,24 @@ class TrainingDataWrapper(CAREamicsTrainData):
        Use in memory dataset if possible, by default True.
    use_n2v2 : bool, optional
        Use N2V2 transformation during training, by default False.
-    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+    struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
        Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
        default "none".
    struct_n2v_span : int, optional
        Span for the structN2V mask, by default 5.
 
+    Returns
+    -------
+    TrainDataModule
+        CAREamics training Lightning data module.
+
    Examples
    --------
-    Create a TrainingDataWrapper with default transforms with a numpy array:
+    Create a TrainingDataModule with default transforms with a numpy array:
    >>> import numpy as np
-    >>> from careamics import TrainingDataWrapper
+    >>> from careamics.lightning import create_train_datamodule
    >>> my_array = np.arange(256).reshape(16, 16)
-    >>> data_module = TrainingDataWrapper(
+    >>> data_module = create_train_datamodule(
    ...     train_data=my_array,
    ...     data_type="array",
    ...     patch_size=(8, 8),
@@ -560,12 +594,12 @@ class TrainingDataWrapper(CAREamicsTrainData):
    For custom data types (those not supported by CAREamics), then one can pass a read
    function and a filter for the files extension:
    >>> import numpy as np
-    >>> from careamics import TrainingDataWrapper
+    >>> from careamics.lightning import create_train_datamodule
    >>>
    >>> def read_npy(path):
    ...     return np.load(path)
    >>>
-    >>> data_module = TrainingDataWrapper(
+    >>> data_module = create_train_datamodule(
    ...     train_data="path/to/data",
    ...     data_type="custom",
    ...     patch_size=(8, 8),
@@ -578,20 +612,15 @@ class TrainingDataWrapper(CAREamicsTrainData):
    If you want to use a different set of transformations, you can pass a list of
    transforms:
    >>> import numpy as np
-    >>> from careamics import TrainingDataWrapper
+    >>> from careamics.lightning import create_train_datamodule
    >>> from careamics.config.support import SupportedTransform
    >>> my_array = np.arange(256).reshape(16, 16)
    >>> my_transforms = [
    ...     {
-    ...         "name": SupportedTransform.NORMALIZE.value,
-    ...         "mean": 0,
-    ...         "std": 1,
-    ...     },
-    ...     {
-    ...         "name": SupportedTransform.N2V_MANIPULATE.value,
+    ...         "name": SupportedTransform.XY_FLIP.value,
    ...     }
    ... ]
-    >>> data_module = TrainingDataWrapper(
+    >>> data_module = create_train_datamodule(
    ...     train_data=my_array,
    ...     data_type="array",
    ...     patch_size=(8, 8),
@@ -600,166 +629,52 @@ class TrainingDataWrapper(CAREamicsTrainData):
    ...     transforms=my_transforms,
    ... )
    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    data_dict: dict[str, Any] = {
+        "mode": "train",
+        "data_type": data_type,
+        "patch_size": patch_size,
+        "axes": axes,
+        "batch_size": batch_size,
+        "dataloader_params": dataloader_params,
+    }
+
+    # if transforms are passed (otherwise it will use the default ones)
+    if transforms is not None:
+        data_dict["transforms"] = transforms
+
+    # validate configuration
+    data_config = DataConfig(**data_dict)
+
+    # N2V specific checks, N2V, structN2V, and transforms
+    if data_config.has_n2v_manipulate():
+        # there is not target, n2v2 and structN2V can be changed
+        if train_target_data is None:
+            data_config.set_N2V2(use_n2v2)
+            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+        else:
+            raise ValueError(
+                "Cannot have both supervised training (target data) and "
+                "N2V manipulation in the transforms. Pass a list of transforms "
+                "that is compatible with your supervised training."
+            )
 
-    def __init__(
-        self,
-        train_data: Union[str, Path, np.ndarray],
-        data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
-        patch_size: List[int],
-        axes: str,
-        batch_size: int,
-        val_data: Optional[Union[str, Path]] = None,
-        transforms: Optional[List[TRANSFORMS_UNION]] = None,
-        train_target_data: Optional[Union[str, Path]] = None,
-        val_target_data: Optional[Union[str, Path]] = None,
-        read_source_func: Optional[Callable] = None,
-        extension_filter: str = "",
-        val_percentage: float = 0.1,
-        val_minimum_patches: int = 5,
-        dataloader_params: Optional[dict] = None,
-        use_in_memory: bool = True,
-        use_n2v2: bool = False,
-        struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
-        struct_n2v_span: int = 5,
-    ) -> None:
-        """
-        LightningDataModule wrapper for training and validation datasets.
-
-        Since the lightning datamodule has no access to the model, make sure that the
-        parameters passed to the datamodule are consistent with the model's requirements
-        and are coherent.
-
-        The data module can be used with Path, str or numpy arrays. In the case of
-        numpy arrays, it loads and computes all the patches in memory. For Path and str
-        inputs, it calculates the total file size and estimate whether it can fit in
-        memory. If it does not, it iterates through the files. This behaviour can be
-        deactivated by setting `use_in_memory` to False, in which case it will
-        always use the iterating dataset to train on a Path or str.
-
-        To use array data, set `data_type` to `array` and pass a numpy array to
-        `train_data`.
-
-        In particular, N2V requires a specific transformation (N2V manipulates), which
-        is not compatible with supervised training. The default transformations applied
-        to the training patches are defined in `careamics.config.data_model`. To use
-        different transformations, pass a list of transforms. See examples for more
-        details.
-
-        By default, CAREamics only supports types defined in
-        `careamics.config.support.SupportedData`. To read custom data types, you can set
-        `data_type` to `custom` and provide a function that returns a numpy array from a
-        path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
-        (e.g. "*.jpeg") to filter the files extension using `extension_filter`.
-
-        In the absence of validation data, the validation data is extracted from the
-        training data. The percentage of the training data to use for validation, as
-        well as the minimum number of patches to split from the training data for
-        validation can be set using `val_percentage` and `val_minimum_patches`,
-        respectively.
-
-        In `dataloader_params`, you can pass any parameter accepted by PyTorch
-        dataloaders, except for `batch_size`, which is set by the `batch_size`
-        parameter.
-
-        Finally, if you intend to use N2V family of algorithms, you can set `use_n2v2`
-        to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to
-        define the axis and span of the structN2V mask. These parameters are without
-        effect if a `train_target_data` or if `transforms` are provided.
-
-        Parameters
-        ----------
-        train_data : Union[str, Path, np.ndarray]
-            Training data.
-        data_type : Union[str, SupportedData]
-            Data type, see `SupportedData` for available options.
-        patch_size : List[int]
-            Patch size, 2D or 3D patch size.
-        axes : str
-            Axes of the data, choosen amongst SCZYX.
-        batch_size : int
-            Batch size.
-        val_data : Optional[Union[str, Path]], optional
-            Validation data, by default None.
-        transforms : Optional[List[TRANSFORMS_UNION]], optional
-            List of transforms to apply to training patches. If None, default transforms
-            are applied.
-        train_target_data : Optional[Union[str, Path]], optional
-            Training target data, by default None.
-        val_target_data : Optional[Union[str, Path]], optional
-            Validation target data, by default None.
-        read_source_func : Optional[Callable], optional
-            Function to read the source data, used if `data_type` is `custom`, by
-            default None.
-        extension_filter : str, optional
-            Filter for file extensions, used if `data_type` is `custom`, by default "".
-        val_percentage : float, optional
-            Percentage of the training data to use for validation if no validation data
-            is given, by default 0.1.
-        val_minimum_patches : int, optional
-            Minimum number of patches to split from the training data for validation if
-            no validation data is given, by default 5.
-        dataloader_params : dict, optional
-            Pytorch dataloader parameters, by default {}.
-        use_in_memory : bool, optional
-            Use in memory dataset if possible, by default True.
-        use_n2v2 : bool, optional
-            Use N2V2 transformation during training, by default False.
-        struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
-            Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
-            default "none".
-        struct_n2v_span : int, optional
-            Span for the structN2V mask, by default 5.
-
-        Raises
-        ------
-        ValueError
-            If a target is set and N2V manipulation is present in the transforms.
-        """
-        if dataloader_params is None:
-            dataloader_params = {}
-        data_dict: Dict[str, Any] = {
-            "mode": "train",
-            "data_type": data_type,
-            "patch_size": patch_size,
-            "axes": axes,
-            "batch_size": batch_size,
-            "dataloader_params": dataloader_params,
-        }
-
-        # if transforms are passed (otherwise it will use the default ones)
-        if transforms is not None:
-            data_dict["transforms"] = transforms
-
-        # validate configuration
-        self.data_config = DataConfig(**data_dict)
-
-        # N2V specific checks, N2V, structN2V, and transforms
-        if self.data_config.has_n2v_manipulate():
-            # there is not target, n2v2 and structN2V can be changed
-            if train_target_data is None:
-                self.data_config.set_N2V2(use_n2v2)
-                self.data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
-            else:
-                raise ValueError(
-                    "Cannot have both supervised training (target data) and "
-                    "N2V manipulation in the transforms. Pass a list of transforms "
-                    "that is compatible with your supervised training."
-                )
-
-        # sanity check on the dataloader parameters
-        if "batch_size" in dataloader_params:
-            # remove it
-            del dataloader_params["batch_size"]
-
-        super().__init__(
-            data_config=self.data_config,
-            train_data=train_data,
-            val_data=val_data,
-            train_data_target=train_target_data,
-            val_data_target=val_target_data,
-            read_source_func=read_source_func,
-            extension_filter=extension_filter,
-            val_percentage=val_percentage,
-            val_minimum_split=val_minimum_patches,
-            use_in_memory=use_in_memory,
-        )
+    # sanity check on the dataloader parameters
+    if "batch_size" in dataloader_params:
+        # remove it
+        del dataloader_params["batch_size"]
+
+    return TrainDataModule(
+        data_config=data_config,
+        train_data=train_data,
+        val_data=val_data,
+        train_data_target=train_target_data,
+        val_data_target=val_target_data,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        val_percentage=val_percentage,
+        val_minimum_split=val_minimum_patches,
+        use_in_memory=use_in_memory,
+    )
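
Taken together, these hunks replace the TrainingDataWrapper class with the create_train_datamodule factory and add a get_data_statistics accessor to TrainDataModule. A short usage sketch follows, based only on the signatures visible in this diff; the call mirrors the docstring examples above, and the statistics call assumes the datasets have already been set up (for instance by a Lightning Trainer).

import numpy as np
from careamics.lightning import create_train_datamodule

# Build the data module from an in-memory array, as in the docstring examples above.
rng = np.random.default_rng(42)
train_array = rng.normal(size=(64, 64)).astype(np.float32)

data_module = create_train_datamodule(
    train_data=train_array,
    data_type="array",
    patch_size=(16, 16),
    axes="YX",
    batch_size=2,
)

# New in this release: per-channel means and standard deviations of the training data.
# get_data_statistics delegates to the training dataset, so it is only meaningful after
# prepare_data()/setup() have run (normally triggered by the Lightning Trainer).
# means, stds = data_module.get_data_statistics()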