careamics 0.1.0rc7__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (54)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +83 -62
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -0
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +2 -0
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +1 -79
  11. careamics/config/configuration_model.py +12 -7
  12. careamics/config/data_model.py +29 -10
  13. careamics/config/inference_model.py +12 -2
  14. careamics/config/optimizer_models.py +6 -0
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/tile_information.py +10 -0
  17. careamics/config/training_model.py +5 -1
  18. careamics/dataset/dataset_utils/__init__.py +0 -6
  19. careamics/dataset/dataset_utils/file_utils.py +1 -1
  20. careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
  21. careamics/dataset/in_memory_dataset.py +37 -21
  22. careamics/dataset/iterable_dataset.py +38 -34
  23. careamics/dataset/iterable_pred_dataset.py +2 -1
  24. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  25. careamics/dataset/patching/patching.py +53 -37
  26. careamics/file_io/__init__.py +7 -0
  27. careamics/file_io/read/__init__.py +11 -0
  28. careamics/file_io/read/get_func.py +56 -0
  29. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
  30. careamics/file_io/write/__init__.py +9 -0
  31. careamics/file_io/write/get_func.py +59 -0
  32. careamics/file_io/write/tiff.py +39 -0
  33. careamics/lightning/__init__.py +17 -0
  34. careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
  35. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
  36. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
  37. careamics/model_io/bmz_io.py +1 -1
  38. careamics/model_io/model_io_utils.py +1 -1
  39. careamics/prediction_utils/__init__.py +0 -2
  40. careamics/prediction_utils/prediction_outputs.py +18 -46
  41. careamics/prediction_utils/stitch_prediction.py +17 -14
  42. careamics/utils/__init__.py +2 -0
  43. careamics/utils/autocorrelation.py +40 -0
  44. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
  45. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
  46. careamics/config/configuration_example.py +0 -86
  47. careamics/dataset/dataset_utils/read_utils.py +0 -27
  48. careamics/prediction_utils/create_pred_datamodule.py +0 -185
  49. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  50. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  51. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  52. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  53. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
  54. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
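The bulk of this release is a reorganization: the top-level Lightning modules move into a `careamics.lightning` subpackage, and the read/write helpers move into a new `careamics.file_io` package. A minimal sketch of how common imports change between rc7 and rc8, based only on the renames above and the import updates in the diffs below (other exports of the new subpackages are not shown in this diff):

    # 0.1.0rc7 import paths, updated in this release (as seen in the diffs below)
    # from careamics.lightning_module import CAREamicsModule
    # from careamics.dataset.dataset_utils import get_read_func

    # 0.1.0rc8 import paths, as used in the diffs below
    from careamics.lightning.lightning_module import CAREamicsModule
    from careamics.lightning import create_train_datamodule
    from careamics.file_io.read import get_read_func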
careamics/{lightning_datamodule.py → lightning/train_data_module.py}
@@ -1,10 +1,11 @@
 """Training and validation Lightning data modules."""
 
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import pytorch_lightning as L
+from numpy.typing import NDArray
 from torch.utils.data import DataLoader
 
 from careamics.config import DataConfig
@@ -12,7 +13,6 @@ from careamics.config.data_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedData
 from careamics.dataset.dataset_utils import (
     get_files_size,
-    get_read_func,
     list_files,
     validate_source_target_files,
 )
@@ -22,6 +22,7 @@ from careamics.dataset.in_memory_dataset import (
 from careamics.dataset.iterable_dataset import (
     PathIterableDataset,
 )
+from careamics.file_io.read import get_read_func
 from careamics.utils import get_logger, get_ram_size
 
 DatasetType = Union[InMemoryDataset, PathIterableDataset]
@@ -29,7 +30,7 @@ DatasetType = Union[InMemoryDataset, PathIterableDataset]
 logger = get_logger(__name__)
 
 
-class CAREamicsTrainData(L.LightningDataModule):
+class TrainDataModule(L.LightningDataModule):
     """
     CAREamics Ligthning training and validation data module.
 
@@ -59,18 +60,18 @@ class CAREamicsTrainData(L.LightningDataModule):
     ----------
     data_config : DataModel
         Pydantic model for CAREamics data configuration.
-    train_data : Union[Path, str, np.ndarray]
+    train_data : pathlib.Path or str or numpy.ndarray
         Training data, can be a path to a folder, a file or a numpy array.
-    val_data : Optional[Union[Path, str, np.ndarray]], optional
+    val_data : pathlib.Path or str or numpy.ndarray, optional
         Validation data, can be a path to a folder, a file or a numpy array, by
         default None.
-    train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+    train_data_target : pathlib.Path or str or numpy.ndarray, optional
         Training target data, can be a path to a folder, a file or a numpy array, by
         default None.
-    val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+    val_data_target : pathlib.Path or str or numpy.ndarray, optional
         Validation target data, can be a path to a folder, a file or a numpy array,
         by default None.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
         Function to read the source data, by default None. Only used for `custom`
         data type (see DataModel).
     extension_filter : str, optional
@@ -95,13 +96,13 @@ class CAREamicsTrainData(L.LightningDataModule):
        Batch size.
    use_in_memory : bool
        Whether to use in memory dataset if possible.
-    train_data : Union[Path, np.ndarray]
+    train_data : pathlib.Path or numpy.ndarray
        Training data.
-    val_data : Optional[Union[Path, np.ndarray]]
+    val_data : pathlib.Path or numpy.ndarray
        Validation data.
-    train_data_target : Optional[Union[Path, np.ndarray]]
+    train_data_target : pathlib.Path or numpy.ndarray
        Training target data.
-    val_data_target : Optional[Union[Path, np.ndarray]]
+    val_data_target : pathlib.Path or numpy.ndarray
        Validation target data.
    val_percentage : float
        Percentage of the training data to use for validation, if no validation data is
@@ -118,10 +119,10 @@ class CAREamicsTrainData(L.LightningDataModule):
     def __init__(
         self,
         data_config: DataConfig,
-        train_data: Union[Path, str, np.ndarray],
-        val_data: Optional[Union[Path, str, np.ndarray]] = None,
-        train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
-        val_data_target: Optional[Union[Path, str, np.ndarray]] = None,
+        train_data: Union[Path, str, NDArray],
+        val_data: Optional[Union[Path, str, NDArray]] = None,
+        train_data_target: Optional[Union[Path, str, NDArray]] = None,
+        val_data_target: Optional[Union[Path, str, NDArray]] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         val_percentage: float = 0.1,
@@ -135,18 +136,18 @@ class CAREamicsTrainData(L.LightningDataModule):
         ----------
         data_config : DataModel
             Pydantic model for CAREamics data configuration.
-        train_data : Union[Path, str, np.ndarray]
+        train_data : pathlib.Path or str or numpy.ndarray
             Training data, can be a path to a folder, a file or a numpy array.
-        val_data : Optional[Union[Path, str, np.ndarray]], optional
+        val_data : pathlib.Path or str or numpy.ndarray, optional
             Validation data, can be a path to a folder, a file or a numpy array, by
             default None.
-        train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        train_data_target : pathlib.Path or str or numpy.ndarray, optional
             Training target data, can be a path to a folder, a file or a numpy array, by
             default None.
-        val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        val_data_target : pathlib.Path or str or numpy.ndarray, optional
             Validation target data, can be a path to a folder, a file or a numpy array,
             by default None.
-        read_source_func : Optional[Callable], optional
+        read_source_func : Callable, optional
             Function to read the source data, by default None. Only used for `custom`
             data type (see DataModel).
         extension_filter : str, optional
@@ -166,7 +167,7 @@ class CAREamicsTrainData(L.LightningDataModule):
         NotImplementedError
             Raised if target data is provided.
         ValueError
-            If the input types are mixed (e.g. Path and np.ndarray).
+            If the input types are mixed (e.g. Path and numpy.ndarray).
         ValueError
             If the data type is `custom` and no `read_source_func` is provided.
         ValueError
@@ -223,21 +224,21 @@ class CAREamicsTrainData(L.LightningDataModule):
         self.use_in_memory: bool = use_in_memory
 
         # data: make data Path or np.ndarray, use type annotations for mypy
-        self.train_data: Union[Path, np.ndarray] = (
+        self.train_data: Union[Path, NDArray] = (
             Path(train_data) if isinstance(train_data, str) else train_data
         )
 
-        self.val_data: Union[Path, np.ndarray] = (
+        self.val_data: Union[Path, NDArray] = (
             Path(val_data) if isinstance(val_data, str) else val_data
         )
 
-        self.train_data_target: Union[Path, np.ndarray] = (
+        self.train_data_target: Union[Path, NDArray] = (
             Path(train_data_target)
             if isinstance(train_data_target, str)
             else train_data_target
         )
 
-        self.val_data_target: Union[Path, np.ndarray] = (
+        self.val_data_target: Union[Path, NDArray] = (
             Path(val_data_target)
             if isinstance(val_data_target, str)
             else val_data_target
@@ -260,7 +261,7 @@ class CAREamicsTrainData(L.LightningDataModule):
         self.extension_filter: str = extension_filter
 
         # Pytorch dataloader parameters
-        self.dataloader_params: Dict[str, Any] = (
+        self.dataloader_params: dict[str, Any] = (
             data_config.dataloader_params if data_config.dataloader_params else {}
         )
 
@@ -298,7 +299,7 @@ class CAREamicsTrainData(L.LightningDataModule):
 
         # same for target data
         if self.train_data_target is not None:
-            self.train_target_files: List[Path] = list_files(
+            self.train_target_files: list[Path] = list_files(
                 self.train_data_target, self.data_type, self.extension_filter
             )
 
@@ -403,7 +404,7 @@ class CAREamicsTrainData(L.LightningDataModule):
            )
 
            # create validation dataset
-            if self.val_files is not None:
+            if self.val_data is not None:
                # create its own dataset
                self.val_dataset = PathIterableDataset(
                    data_config=self.data_config,
@@ -423,9 +424,19 @@ class CAREamicsTrainData(L.LightningDataModule):
                # extract validation from the training patches
                self.val_dataset = self.train_dataset.split_dataset(
                    percentage=self.val_percentage,
-                    minimum_files=self.val_minimum_split,
+                    minimum_number=self.val_minimum_split,
                )
 
+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        Returns
+        -------
+        tuple of list
+            Means and standard deviations across channels of the training data.
+        """
+        return self.train_dataset.get_data_statistics()
+
     def train_dataloader(self) -> Any:
         """
         Create a dataloader for training.
@@ -454,12 +465,30 @@ class CAREamicsTrainData(L.LightningDataModule):
         )
 
 
-class TrainingDataWrapper(CAREamicsTrainData):
-    """
-    Wrapper around the CAREamics Lightning training data module.
-
-    This class is used to explicitely pass the parameters usually contained in a
-    `data_model` configuration.
+def create_train_datamodule(
+    train_data: Union[str, Path, NDArray],
+    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+    patch_size: list[int],
+    axes: str,
+    batch_size: int,
+    val_data: Optional[Union[str, Path, NDArray]] = None,
+    transforms: Optional[list[TRANSFORMS_UNION]] = None,
+    train_target_data: Optional[Union[str, Path, NDArray]] = None,
+    val_target_data: Optional[Union[str, Path, NDArray]] = None,
+    read_source_func: Optional[Callable] = None,
+    extension_filter: str = "",
+    val_percentage: float = 0.1,
+    val_minimum_patches: int = 5,
+    dataloader_params: Optional[dict] = None,
+    use_in_memory: bool = True,
+    use_n2v2: bool = False,
+    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+    struct_n2v_span: int = 5,
+) -> TrainDataModule:
+    """Create a TrainDataModule.
+
+    This function is used to explicitely pass the parameters usually contained in a
+    `data_model` configuration to a TrainDataModule.
 
     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements and
@@ -501,26 +530,26 @@ class TrainingDataWrapper(CAREamicsTrainData):
 
     Parameters
     ----------
-    train_data : Union[str, Path, np.ndarray]
+    train_data : pathlib.Path or str or numpy.ndarray
         Training data.
-    data_type : Union[str, SupportedData]
+    data_type : {"array", "tiff", "custom"}
         Data type, see `SupportedData` for available options.
-    patch_size : List[int]
+    patch_size : list of int
         Patch size, 2D or 3D patch size.
     axes : str
         Axes of the data, choosen amongst SCZYX.
     batch_size : int
         Batch size.
-    val_data : Optional[Union[str, Path]], optional
+    val_data : pathlib.Path or str or numpy.ndarray, optional
         Validation data, by default None.
-    transforms : List[TRANSFORMS_UNION], optional
+    transforms : list of Transforms, optional
         List of transforms to apply to training patches. If None, default transforms
         are applied.
-    train_target_data : Optional[Union[str, Path]], optional
+    train_target_data : pathlib.Path or str or numpy.ndarray, optional
         Training target data, by default None.
-    val_target_data : Optional[Union[str, Path]], optional
+    val_target_data : pathlib.Path or str or numpy.ndarray, optional
         Validation target data, by default None.
-    read_source_func : Optional[Callable], optional
+    read_source_func : Callable, optional
         Function to read the source data, used if `data_type` is `custom`, by
         default None.
     extension_filter : str, optional
@@ -537,19 +566,24 @@ class TrainingDataWrapper(CAREamicsTrainData):
        Use in memory dataset if possible, by default True.
    use_n2v2 : bool, optional
        Use N2V2 transformation during training, by default False.
-    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
+    struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
        Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
        default "none".
    struct_n2v_span : int, optional
        Span for the structN2V mask, by default 5.
 
+    Returns
+    -------
+    TrainDataModule
+        CAREamics training Lightning data module.
+
    Examples
    --------
-    Create a TrainingDataWrapper with default transforms with a numpy array:
+    Create a TrainingDataModule with default transforms with a numpy array:
    >>> import numpy as np
-    >>> from careamics import TrainingDataWrapper
+    >>> from careamics.lightning import create_train_datamodule
    >>> my_array = np.arange(256).reshape(16, 16)
-    >>> data_module = TrainingDataWrapper(
+    >>> data_module = create_train_datamodule(
    ...     train_data=my_array,
    ...     data_type="array",
    ...     patch_size=(8, 8),
@@ -560,12 +594,12 @@ class TrainingDataWrapper(CAREamicsTrainData):
    For custom data types (those not supported by CAREamics), then one can pass a read
    function and a filter for the files extension:
    >>> import numpy as np
-    >>> from careamics import TrainingDataWrapper
+    >>> from careamics.lightning import create_train_datamodule
    >>>
    >>> def read_npy(path):
    ...     return np.load(path)
    >>>
-    >>> data_module = TrainingDataWrapper(
+    >>> data_module = create_train_datamodule(
    ...     train_data="path/to/data",
    ...     data_type="custom",
    ...     patch_size=(8, 8),
@@ -578,7 +612,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
    If you want to use a different set of transformations, you can pass a list of
    transforms:
    >>> import numpy as np
-    >>> from careamics import TrainingDataWrapper
+    >>> from careamics.lightning import create_train_datamodule
    >>> from careamics.config.support import SupportedTransform
    >>> my_array = np.arange(256).reshape(16, 16)
    >>> my_transforms = [
@@ -586,7 +620,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
    ...         "name": SupportedTransform.XY_FLIP.value,
    ...     }
    ... ]
-    >>> data_module = TrainingDataWrapper(
+    >>> data_module = create_train_datamodule(
    ...     train_data=my_array,
    ...     data_type="array",
    ...     patch_size=(8, 8),
@@ -595,166 +629,52 @@ class TrainingDataWrapper(CAREamicsTrainData):
    ...     transforms=my_transforms,
    ... )
    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    data_dict: dict[str, Any] = {
+        "mode": "train",
+        "data_type": data_type,
+        "patch_size": patch_size,
+        "axes": axes,
+        "batch_size": batch_size,
+        "dataloader_params": dataloader_params,
+    }
+
+    # if transforms are passed (otherwise it will use the default ones)
+    if transforms is not None:
+        data_dict["transforms"] = transforms
+
+    # validate configuration
+    data_config = DataConfig(**data_dict)
+
+    # N2V specific checks, N2V, structN2V, and transforms
+    if data_config.has_n2v_manipulate():
+        # there is not target, n2v2 and structN2V can be changed
+        if train_target_data is None:
+            data_config.set_N2V2(use_n2v2)
+            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+        else:
+            raise ValueError(
+                "Cannot have both supervised training (target data) and "
+                "N2V manipulation in the transforms. Pass a list of transforms "
+                "that is compatible with your supervised training."
+            )
 
-    def __init__(
-        self,
-        train_data: Union[str, Path, np.ndarray],
-        data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
-        patch_size: List[int],
-        axes: str,
-        batch_size: int,
-        val_data: Optional[Union[str, Path]] = None,
-        transforms: Optional[List[TRANSFORMS_UNION]] = None,
-        train_target_data: Optional[Union[str, Path]] = None,
-        val_target_data: Optional[Union[str, Path]] = None,
-        read_source_func: Optional[Callable] = None,
-        extension_filter: str = "",
-        val_percentage: float = 0.1,
-        val_minimum_patches: int = 5,
-        dataloader_params: Optional[dict] = None,
-        use_in_memory: bool = True,
-        use_n2v2: bool = False,
-        struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
-        struct_n2v_span: int = 5,
-    ) -> None:
-        """
-        LightningDataModule wrapper for training and validation datasets.
-
-        Since the lightning datamodule has no access to the model, make sure that the
-        parameters passed to the datamodule are consistent with the model's requirements
-        and are coherent.
-
-        The data module can be used with Path, str or numpy arrays. In the case of
-        numpy arrays, it loads and computes all the patches in memory. For Path and str
-        inputs, it calculates the total file size and estimate whether it can fit in
-        memory. If it does not, it iterates through the files. This behaviour can be
-        deactivated by setting `use_in_memory` to False, in which case it will
-        always use the iterating dataset to train on a Path or str.
-
-        To use array data, set `data_type` to `array` and pass a numpy array to
-        `train_data`.
-
-        In particular, N2V requires a specific transformation (N2V manipulates), which
-        is not compatible with supervised training. The default transformations applied
-        to the training patches are defined in `careamics.config.data_model`. To use
-        different transformations, pass a list of transforms. See examples for more
-        details.
-
-        By default, CAREamics only supports types defined in
-        `careamics.config.support.SupportedData`. To read custom data types, you can set
-        `data_type` to `custom` and provide a function that returns a numpy array from a
-        path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
-        (e.g. "*.jpeg") to filter the files extension using `extension_filter`.
-
-        In the absence of validation data, the validation data is extracted from the
-        training data. The percentage of the training data to use for validation, as
-        well as the minimum number of patches to split from the training data for
-        validation can be set using `val_percentage` and `val_minimum_patches`,
-        respectively.
-
-        In `dataloader_params`, you can pass any parameter accepted by PyTorch
-        dataloaders, except for `batch_size`, which is set by the `batch_size`
-        parameter.
-
-        Finally, if you intend to use N2V family of algorithms, you can set `use_n2v2`
-        to use N2V2, and set the `struct_n2v_axis` and `struct_n2v_span` parameters to
-        define the axis and span of the structN2V mask. These parameters are without
-        effect if a `train_target_data` or if `transforms` are provided.
-
-        Parameters
-        ----------
-        train_data : Union[str, Path, np.ndarray]
-            Training data.
-        data_type : Union[str, SupportedData]
-            Data type, see `SupportedData` for available options.
-        patch_size : List[int]
-            Patch size, 2D or 3D patch size.
-        axes : str
-            Axes of the data, choosen amongst SCZYX.
-        batch_size : int
-            Batch size.
-        val_data : Optional[Union[str, Path]], optional
-            Validation data, by default None.
-        transforms : Optional[List[TRANSFORMS_UNION]], optional
-            List of transforms to apply to training patches. If None, default transforms
-            are applied.
-        train_target_data : Optional[Union[str, Path]], optional
-            Training target data, by default None.
-        val_target_data : Optional[Union[str, Path]], optional
-            Validation target data, by default None.
-        read_source_func : Optional[Callable], optional
-            Function to read the source data, used if `data_type` is `custom`, by
-            default None.
-        extension_filter : str, optional
-            Filter for file extensions, used if `data_type` is `custom`, by default "".
-        val_percentage : float, optional
-            Percentage of the training data to use for validation if no validation data
-            is given, by default 0.1.
-        val_minimum_patches : int, optional
-            Minimum number of patches to split from the training data for validation if
-            no validation data is given, by default 5.
-        dataloader_params : dict, optional
-            Pytorch dataloader parameters, by default {}.
-        use_in_memory : bool, optional
-            Use in memory dataset if possible, by default True.
-        use_n2v2 : bool, optional
-            Use N2V2 transformation during training, by default False.
-        struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
-            Axis for the structN2V mask, only applied if `struct_n2v_axis` is `none`, by
-            default "none".
-        struct_n2v_span : int, optional
-            Span for the structN2V mask, by default 5.
-
-        Raises
-        ------
-        ValueError
-            If a target is set and N2V manipulation is present in the transforms.
-        """
-        if dataloader_params is None:
-            dataloader_params = {}
-        data_dict: Dict[str, Any] = {
-            "mode": "train",
-            "data_type": data_type,
-            "patch_size": patch_size,
-            "axes": axes,
-            "batch_size": batch_size,
-            "dataloader_params": dataloader_params,
-        }
-
-        # if transforms are passed (otherwise it will use the default ones)
-        if transforms is not None:
-            data_dict["transforms"] = transforms
-
-        # validate configuration
-        self.data_config = DataConfig(**data_dict)
-
-        # N2V specific checks, N2V, structN2V, and transforms
-        if self.data_config.has_n2v_manipulate():
-            # there is not target, n2v2 and structN2V can be changed
-            if train_target_data is None:
-                self.data_config.set_N2V2(use_n2v2)
-                self.data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
-            else:
-                raise ValueError(
-                    "Cannot have both supervised training (target data) and "
-                    "N2V manipulation in the transforms. Pass a list of transforms "
-                    "that is compatible with your supervised training."
-                )
-
-        # sanity check on the dataloader parameters
-        if "batch_size" in dataloader_params:
-            # remove it
-            del dataloader_params["batch_size"]
-
-        super().__init__(
-            data_config=self.data_config,
-            train_data=train_data,
-            val_data=val_data,
-            train_data_target=train_target_data,
-            val_data_target=val_target_data,
-            read_source_func=read_source_func,
-            extension_filter=extension_filter,
-            val_percentage=val_percentage,
-            val_minimum_split=val_minimum_patches,
-            use_in_memory=use_in_memory,
-        )
+    # sanity check on the dataloader parameters
+    if "batch_size" in dataloader_params:
+        # remove it
+        del dataloader_params["batch_size"]
+
+    return TrainDataModule(
+        data_config=data_config,
+        train_data=train_data,
+        val_data=val_data,
+        train_data_target=train_target_data,
+        val_data_target=val_target_data,
+        read_source_func=read_source_func,
+        extension_filter=extension_filter,
+        val_percentage=val_percentage,
+        val_minimum_split=val_minimum_patches,
+        use_in_memory=use_in_memory,
+    )
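In practice, rc7 code that instantiated `TrainingDataWrapper` now goes through the `create_train_datamodule` factory, which validates a `DataConfig` and returns a plain `TrainDataModule`. A minimal sketch of the new usage, mirroring the docstring example above; the `prepare_data`/`setup` calls are the standard Lightning hooks (normally invoked by the `Trainer`), shown here manually since the new `get_data_statistics` method delegates to the training dataset created during setup:

    import numpy as np
    from careamics.lightning import create_train_datamodule

    # rc7: data_module = TrainingDataWrapper(...)
    # rc8: the factory returns a TrainDataModule
    my_array = np.arange(256).reshape(16, 16)
    data_module = create_train_datamodule(
        train_data=my_array,
        data_type="array",
        patch_size=(8, 8),
        axes="YX",
        batch_size=2,
    )

    # new in rc8: per-channel means and stds of the training data,
    # available once the datamodule has been set up
    data_module.prepare_data()
    data_module.setup("fit")
    means, stds = data_module.get_data_statistics()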
careamics/model_io/bmz_io.py
@@ -12,7 +12,7 @@ from torch import __version__, load, save
 
 from careamics.config import Configuration, load_configuration, save_configuration
 from careamics.config.support import SupportedArchitecture
-from careamics.lightning_module import CAREamicsModule
+from careamics.lightning.lightning_module import CAREamicsModule
 
 from .bioimage import (
     create_env_text,
careamics/model_io/model_io_utils.py
@@ -6,7 +6,7 @@ from typing import Tuple, Union
 import torch
 
 from careamics.config import Configuration
-from careamics.lightning_module import CAREamicsModule
+from careamics.lightning.lightning_module import CAREamicsModule
 from careamics.model_io.bmz_io import load_from_bmz
 from careamics.utils import check_path_exists
 
careamics/prediction_utils/__init__.py
@@ -1,12 +1,10 @@
 """Package to house various prediction utilies."""
 
 __all__ = [
-    "create_pred_datamodule",
     "stitch_prediction",
     "stitch_prediction_single",
     "convert_outputs",
 ]
 
-from .create_pred_datamodule import create_pred_datamodule
 from .prediction_outputs import convert_outputs
 from .stitch_prediction import stitch_prediction, stitch_prediction_single
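With `create_pred_datamodule` removed, `careamics.prediction_utils` keeps only the stitching and output-conversion helpers; per the file moves in the list above, prediction datamodule construction now lives under `careamics.lightning` (`predict_data_module.py`, whose contents are not shown in this diff). The surviving imports, per the updated `__all__` above:

    # remaining public API of careamics.prediction_utils in rc8
    from careamics.prediction_utils import (
        convert_outputs,
        stitch_prediction,
        stitch_prediction_single,
    )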