careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of careamics might be problematic.
Files changed (66)
  1. careamics/__init__.py +8 -6
  2. careamics/careamist.py +30 -29
  3. careamics/config/__init__.py +12 -9
  4. careamics/config/algorithm_model.py +5 -5
  5. careamics/config/architectures/unet_model.py +1 -0
  6. careamics/config/callback_model.py +1 -0
  7. careamics/config/configuration_example.py +87 -0
  8. careamics/config/configuration_factory.py +285 -78
  9. careamics/config/configuration_model.py +22 -23
  10. careamics/config/data_model.py +62 -160
  11. careamics/config/inference_model.py +20 -21
  12. careamics/config/references/algorithm_descriptions.py +1 -0
  13. careamics/config/references/references.py +1 -0
  14. careamics/config/support/supported_extraction_strategies.py +1 -0
  15. careamics/config/support/supported_optimizers.py +3 -3
  16. careamics/config/training_model.py +2 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +2 -1
  18. careamics/config/transformations/nd_flip_model.py +7 -12
  19. careamics/config/transformations/normalize_model.py +2 -1
  20. careamics/config/transformations/transform_model.py +1 -0
  21. careamics/config/transformations/xy_random_rotate90_model.py +7 -9
  22. careamics/config/validators/validator_utils.py +1 -0
  23. careamics/conftest.py +1 -0
  24. careamics/dataset/dataset_utils/__init__.py +0 -1
  25. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  26. careamics/dataset/in_memory_dataset.py +17 -48
  27. careamics/dataset/iterable_dataset.py +16 -71
  28. careamics/dataset/patching/__init__.py +0 -7
  29. careamics/dataset/patching/patching.py +1 -0
  30. careamics/dataset/patching/sequential_patching.py +6 -6
  31. careamics/dataset/patching/tiled_patching.py +10 -6
  32. careamics/lightning_datamodule.py +123 -49
  33. careamics/lightning_module.py +7 -7
  34. careamics/lightning_prediction_datamodule.py +59 -48
  35. careamics/losses/__init__.py +0 -1
  36. careamics/losses/loss_factory.py +1 -0
  37. careamics/model_io/__init__.py +0 -1
  38. careamics/model_io/bioimage/_readme_factory.py +2 -1
  39. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  40. careamics/model_io/bioimage/model_description.py +4 -3
  41. careamics/model_io/bmz_io.py +8 -7
  42. careamics/model_io/model_io_utils.py +4 -4
  43. careamics/models/layers.py +1 -0
  44. careamics/models/model_factory.py +1 -0
  45. careamics/models/unet.py +91 -17
  46. careamics/prediction/stitch_prediction.py +1 -0
  47. careamics/transforms/__init__.py +2 -23
  48. careamics/transforms/compose.py +98 -0
  49. careamics/transforms/n2v_manipulate.py +18 -23
  50. careamics/transforms/nd_flip.py +38 -64
  51. careamics/transforms/normalize.py +45 -34
  52. careamics/transforms/pixel_manipulation.py +2 -2
  53. careamics/transforms/transform.py +33 -0
  54. careamics/transforms/tta.py +2 -2
  55. careamics/transforms/xy_random_rotate90.py +41 -68
  56. careamics/utils/__init__.py +0 -1
  57. careamics/utils/context.py +1 -0
  58. careamics/utils/logging.py +1 -0
  59. careamics/utils/metrics.py +1 -0
  60. careamics/utils/torch_utils.py +1 -0
  61. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  62. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  63. careamics/dataset/patching/patch_transform.py +0 -44
  64. careamics-0.1.0rc3.dist-info/RECORD +0 -109
  65. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  66. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
careamics/lightning_datamodule.py
@@ -1,12 +1,13 @@
+"""Training and validation Lightning data modules."""
+
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 import numpy as np
 import pytorch_lightning as L
-from albumentations import Compose
 from torch.utils.data import DataLoader
 
-from careamics.config import DataModel
+from careamics.config import DataConfig
 from careamics.config.data_model import TRANSFORMS_UNION
 from careamics.config.support import SupportedData
 from careamics.dataset.dataset_utils import (
@@ -28,9 +29,9 @@ DatasetType = Union[InMemoryDataset, PathIterableDataset]
 logger = get_logger(__name__)
 
 
-class CAREamicsWood(L.LightningDataModule):
+class CAREamicsTrainData(L.LightningDataModule):
     """
-    LightningDataModule for training and validation datasets.
+    CAREamics Lightning training and validation data module.
 
     The data module can be used with Path, str or numpy arrays. In the case of
     numpy arrays, it loads and computes all the patches in memory. For Path and str
@@ -53,11 +54,70 @@ class CAREamicsWood(L.LightningDataModule):
 
     You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
     "*.czi") to filter the files extension using `extension_filter`.
+
+    Parameters
+    ----------
+    data_config : DataModel
+        Pydantic model for CAREamics data configuration.
+    train_data : Union[Path, str, np.ndarray]
+        Training data, can be a path to a folder, a file or a numpy array.
+    val_data : Optional[Union[Path, str, np.ndarray]], optional
+        Validation data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        Training target data, can be a path to a folder, a file or a numpy array, by
+        default None.
+    val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+        Validation target data, can be a path to a folder, a file or a numpy array,
+        by default None.
+    read_source_func : Optional[Callable], optional
+        Function to read the source data, by default None. Only used for `custom`
+        data type (see DataModel).
+    extension_filter : str, optional
+        Filter for file extensions, by default "". Only used for `custom` data types
+        (see DataModel).
+    val_percentage : float, optional
+        Percentage of the training data to use for validation, by default 0.1. Only
+        used if `val_data` is None.
+    val_minimum_split : int, optional
+        Minimum number of patches or files to split from the training data for
+        validation, by default 5. Only used if `val_data` is None.
+    use_in_memory : bool, optional
+        Use in memory dataset if possible, by default True.
+
+    Attributes
+    ----------
+    data_config : DataModel
+        CAREamics data configuration.
+    data_type : SupportedData
+        Expected data type, one of "tiff", "array" or "custom".
+    batch_size : int
+        Batch size.
+    use_in_memory : bool
+        Whether to use in memory dataset if possible.
+    train_data : Union[Path, str, np.ndarray]
+        Training data.
+    val_data : Optional[Union[Path, str, np.ndarray]]
+        Validation data.
+    train_data_target : Optional[Union[Path, str, np.ndarray]]
+        Training target data.
+    val_data_target : Optional[Union[Path, str, np.ndarray]]
+        Validation target data.
+    val_percentage : float
+        Percentage of the training data to use for validation, if no validation data is
+        provided.
+    val_minimum_split : int
+        Minimum number of patches or files to split from the training data for
+        validation, if no validation data is provided.
+    read_source_func : Optional[Callable]
+        Function to read the source data, used if `data_type` is `custom`.
+    extension_filter : str
+        Filter for file extensions, used if `data_type` is `custom`.
     """
 
     def __init__(
         self,
-        data_config: DataModel,
+        data_config: DataConfig,
         train_data: Union[Path, str, np.ndarray],
         val_data: Optional[Union[Path, str, np.ndarray]] = None,
         train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
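
For orientation, here is a minimal usage sketch of the renamed data module. The parameter names come from the docstring above; the exact set of `DataConfig` fields required is an assumption:

    import numpy as np

    from careamics.config import DataConfig
    from careamics.lightning_datamodule import CAREamicsTrainData

    # Field names follow the docstring above; the exact set of required
    # fields is an assumption.
    config = DataConfig(
        data_type="array",
        patch_size=[8, 8],
        axes="YX",
        batch_size=2,
    )

    # With a numpy array input, all patches are loaded and computed in memory.
    train_array = np.arange(256, dtype=np.float32).reshape(16, 16)
    data_module = CAREamicsTrainData(data_config=config, train_data=train_array)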
@@ -98,6 +158,8 @@ class CAREamicsWood(L.LightningDataModule):
         val_minimum_split : int, optional
             Minimum number of patches or files to split from the training data for
             validation, by default 5. Only used if `val_data` is None.
+        use_in_memory : bool, optional
+            Use in memory dataset if possible, by default True.
 
         Raises
         ------
@@ -128,25 +190,30 @@ class CAREamicsWood(L.LightningDataModule):
         if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
             raise ValueError(
                 f"Data type {SupportedData.CUSTOM} is not allowed without "
-                f"specifying a `read_source_func`."
+                f"specifying a `read_source_func` and an `extension_filter`."
             )
 
-        # and that arrays are passed, if array type specified
-        elif data_config.data_type == SupportedData.ARRAY and not isinstance(
-            train_data, np.ndarray
+        # check correct input type
+        if (
+            isinstance(train_data, np.ndarray)
+            and data_config.data_type != SupportedData.ARRAY
         ):
             raise ValueError(
-                f"Expected array input (see configuration.data.data_type), but got "
-                f"{type(train_data)} instead."
+                f"Received a numpy array as input, but the data type was set to "
+                f"{data_config.data_type}. Set the data type in the configuration "
+                f"to {SupportedData.ARRAY} to train on numpy arrays."
             )
 
         # and that Path or str are passed, if tiff file type specified
-        elif data_config.data_type == SupportedData.TIFF and (
-            not isinstance(train_data, Path) and not isinstance(train_data, str)
+        elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
+            data_config.data_type != SupportedData.TIFF
+            and data_config.data_type != SupportedData.CUSTOM
         ):
             raise ValueError(
-                f"Expected Path or str input (see configuration.data.data_type), "
-                f"but got {type(train_data)} instead."
+                f"Received a path as input, but the data type was neither set to "
+                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                f"in the configuration to {SupportedData.TIFF} or "
+                f"{SupportedData.CUSTOM} to train on files."
             )
 
         # configuration
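
The reworked checks above key off the type of the input rather than the configured data type. A sketch of the resulting behavior, with the same assumed `DataConfig` fields as in the earlier example:

    import numpy as np

    from careamics.config import DataConfig
    from careamics.lightning_datamodule import CAREamicsTrainData

    # A numpy array combined with data_type="tiff" now raises a ValueError
    # that explains how to fix the configuration.
    config = DataConfig(data_type="tiff", patch_size=[8, 8], axes="YX", batch_size=1)
    try:
        CAREamicsTrainData(data_config=config, train_data=np.zeros((16, 16)))
    except ValueError as err:
        print(err)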
@@ -231,7 +298,15 @@ class CAREamicsWood(L.LightningDataModule):
             validate_source_target_files(self.val_files, self.val_target_files)
 
     def setup(self, *args: Any, **kwargs: Any) -> None:
-        """Hook called at the beginning of fit, validate, or predict."""
+        """Hook called at the beginning of fit, validate, or predict.
+
+        Parameters
+        ----------
+        *args : Any
+            Unused.
+        **kwargs : Any
+            Unused.
+        """
         # if numpy array
         if self.data_type == SupportedData.ARRAY:
             # train dataset
@@ -266,9 +341,9 @@ class CAREamicsWood(L.LightningDataModule):
                self.train_dataset = InMemoryDataset(
                    data_config=self.data_config,
                    inputs=self.train_files,
-                    data_target=self.train_target_files
-                    if self.train_data_target
-                    else None,
+                    data_target=(
+                        self.train_target_files if self.train_data_target else None
+                    ),
                    read_source_func=self.read_source_func,
                )
 
@@ -277,9 +352,9 @@ class CAREamicsWood(L.LightningDataModule):
                self.val_dataset = InMemoryDataset(
                    data_config=self.data_config,
                    inputs=self.val_files,
-                    data_target=self.val_target_files
-                    if self.val_data_target
-                    else None,
+                    data_target=(
+                        self.val_target_files if self.val_data_target else None
+                    ),
                    read_source_func=self.read_source_func,
                )
            else:
@@ -295,9 +370,9 @@ class CAREamicsWood(L.LightningDataModule):
                self.train_dataset = PathIterableDataset(
                    data_config=self.data_config,
                    src_files=self.train_files,
-                    target_files=self.train_target_files
-                    if self.train_data_target
-                    else None,
+                    target_files=(
+                        self.train_target_files if self.train_data_target else None
+                    ),
                    read_source_func=self.read_source_func,
                )
 
@@ -307,9 +382,9 @@ class CAREamicsWood(L.LightningDataModule):
                self.val_dataset = PathIterableDataset(
                    data_config=self.data_config,
                    src_files=self.val_files,
-                    target_files=self.val_target_files
-                    if self.val_data_target
-                    else None,
+                    target_files=(
+                        self.val_target_files if self.val_data_target else None
+                    ),
                    read_source_func=self.read_source_func,
                )
            elif len(self.train_files) <= self.val_minimum_split:
@@ -353,9 +428,12 @@ class CAREamicsWood(L.LightningDataModule):
        )
 
 
-class CAREamicsTrainDataModule(CAREamicsWood):
+class TrainingDataWrapper(CAREamicsTrainData):
     """
-    LightningDataModule wrapper for training and validation datasets.
+    Wrapper around the CAREamics Lightning training data module.
+
+    This class is used to explicitly pass the parameters usually contained in a
+    `data_model` configuration.
 
     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements and
@@ -374,8 +452,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
     In particular, N2V requires a specific transformation (N2V manipulates), which is
     not compatible with supervised training. The default transformations applied to the
     training patches are defined in `careamics.config.data_model`. To use different
-    transformations, pass a list of transforms or an albumentation `Compose` as
-    `transforms` parameter. See examples for more details.
+    transformations, pass a list of transforms. See examples for more details.
 
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -410,7 +487,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
        Batch size.
    val_data : Optional[Union[str, Path]], optional
        Validation data, by default None.
-    transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+    transforms : List[TRANSFORMS_UNION], optional
        List of transforms to apply to training patches. If None, default transforms
        are applied.
    train_target_data : Optional[Union[str, Path]], optional
@@ -442,11 +519,11 @@ class CAREamicsTrainDataModule(CAREamicsWood):
 
    Examples
    --------
-    Create a CAREamicsTrainDataModule with default transforms with a numpy array:
+    Create a TrainingDataWrapper with default transforms with a numpy array:
    >>> import numpy as np
-    >>> from careamics import CAREamicsTrainDataModule
+    >>> from careamics import TrainingDataWrapper
    >>> my_array = np.arange(256).reshape(16, 16)
-    >>> data_module = CAREamicsTrainDataModule(
+    >>> data_module = TrainingDataWrapper(
    ...     train_data=my_array,
    ...     data_type="array",
    ...     patch_size=(8, 8),
@@ -457,12 +534,12 @@ class CAREamicsTrainDataModule(CAREamicsWood):
    For custom data types (those not supported by CAREamics), then one can pass a read
    function and a filter for the files extension:
    >>> import numpy as np
-    >>> from careamics import CAREamicsTrainDataModule
+    >>> from careamics import TrainingDataWrapper
    >>>
    >>> def read_npy(path):
    ...     return np.load(path)
    >>>
-    >>> data_module = CAREamicsTrainDataModule(
+    >>> data_module = TrainingDataWrapper(
    ...     train_data="path/to/data",
    ...     data_type="custom",
    ...     patch_size=(8, 8),
475
552
  If you want to use a different set of transformations, you can pass a list of
476
553
  transforms:
477
554
  >>> import numpy as np
478
- >>> from careamics import CAREamicsTrainDataModule
555
+ >>> from careamics import TrainingDataWrapper
479
556
  >>> from careamics.config.support import SupportedTransform
480
557
  >>> my_array = np.arange(256).reshape(16, 16)
481
558
  >>> my_transforms = [
@@ -488,7 +565,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
488
565
  ... "name": SupportedTransform.N2V_MANIPULATE.value,
489
566
  ... }
490
567
  ... ]
491
- >>> data_module = CAREamicsTrainDataModule(
568
+ >>> data_module = TrainingDataWrapper(
492
569
  ... train_data=my_array,
493
570
  ... data_type="array",
494
571
  ... patch_size=(8, 8),
@@ -506,7 +583,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
506
583
  axes: str,
507
584
  batch_size: int,
508
585
  val_data: Optional[Union[str, Path]] = None,
509
- transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
586
+ transforms: Optional[List[TRANSFORMS_UNION]] = None,
510
587
  train_target_data: Optional[Union[str, Path]] = None,
511
588
  val_target_data: Optional[Union[str, Path]] = None,
512
589
  read_source_func: Optional[Callable] = None,
@@ -539,8 +616,8 @@ class CAREamicsTrainDataModule(CAREamicsWood):
        In particular, N2V requires a specific transformation (N2V manipulates), which
        is not compatible with supervised training. The default transformations applied
        to the training patches are defined in `careamics.config.data_model`. To use
-        different transformations, pass a list of transforms or an albumentation
-        `Compose` as `transforms` parameter. See examples for more details.
+        different transformations, pass a list of transforms. See examples for more
+        details.
 
        By default, CAREamics only supports types defined in
        `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -577,7 +654,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
            Batch size.
        val_data : Optional[Union[str, Path]], optional
            Validation data, by default None.
-        transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+        transforms : Optional[List[TRANSFORMS_UNION]], optional
            List of transforms to apply to training patches. If None, default transforms
            are applied.
        train_target_data : Optional[Union[str, Path]], optional
@@ -628,13 +705,10 @@ class CAREamicsTrainDataModule(CAREamicsWood):
            data_dict["transforms"] = transforms
 
        # validate configuration
-        self.data_config = DataModel(**data_dict)
+        self.data_config = DataConfig(**data_dict)
 
        # N2V specific checks, N2V, structN2V, and transforms
-        if (
-            self.data_config.has_transform_list()
-            and self.data_config.has_n2v_manipulate()
-        ):
+        if self.data_config.has_n2v_manipulate():
            # there is no target, n2v2 and structN2V can be changed
            if train_target_data is None:
                self.data_config.set_N2V2(use_n2v2)
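
Per the simplified check above, the N2V-specific flags are applied whenever the validated configuration contains the N2V manipulation transform and no target data is given. A hedged sketch of toggling N2V2 through the wrapper; `use_n2v2` appears in the diff, but its default value and position in the signature are assumptions:

    import numpy as np

    from careamics import TrainingDataWrapper

    my_array = np.arange(256).reshape(16, 16)
    data_module = TrainingDataWrapper(
        train_data=my_array,
        data_type="array",
        patch_size=(8, 8),
        axes="YX",
        batch_size=2,
        use_n2v2=True,  # forwarded to DataConfig.set_N2V2 when no target is set
    )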
careamics/lightning_module.py
@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
 import pytorch_lightning as L
 from torch import Tensor, nn
 
-from careamics.config import AlgorithmModel
+from careamics.config import AlgorithmConfig
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -17,7 +17,7 @@ from careamics.transforms import Denormalize, ImageRestorationTTA
 from careamics.utils.torch_utils import get_optimizer, get_scheduler
 
 
-class CAREamicsKiln(L.LightningModule):
+class CAREamicsModule(L.LightningModule):
     """
     CAREamics Lightning module.
 
@@ -38,7 +38,7 @@ class CAREamicsKiln(L.LightningModule):
        Learning rate scheduler name.
    """
 
-    def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None:
+    def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
        """
        CAREamics Lightning module.
 
@@ -53,7 +53,7 @@ class CAREamicsKiln(L.LightningModule):
        super().__init__()
        # if loading from a checkpoint, AlgorithmModel needs to be instantiated
        if isinstance(algorithm_config, dict):
-            algorithm_config = AlgorithmModel(**algorithm_config)
+            algorithm_config = AlgorithmConfig(**algorithm_config)
 
        # create model and loss function
        self.model: nn.Module = model_factory(algorithm_config.model)
@@ -162,7 +162,7 @@ class CAREamicsKiln(L.LightningModule):
            mean=self._trainer.datamodule.predict_dataset.mean,
            std=self._trainer.datamodule.predict_dataset.std,
        )
-        denormalized_output = denorm(image=output)["image"]
+        denormalized_output, _ = denorm(patch=output)
 
        if len(aux) > 0:
            return denormalized_output, aux
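
The change above replaces the albumentations-style call (`denorm(image=output)["image"]`) with a plain callable that takes a `patch` and returns a tuple. A minimal sketch of the new convention; the mean/std values and patch shape are illustrative:

    import numpy as np

    from careamics.transforms import Denormalize

    denorm = Denormalize(mean=0.5, std=0.2)

    patch = np.random.rand(1, 8, 8).astype(np.float32)
    # The transform now returns a tuple; the second element is unused here.
    denormalized, _ = denorm(patch=patch)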
@@ -192,7 +192,7 @@ class CAREamicsKiln(L.LightningModule):
        }
 
 
-class CAREamicsModule(CAREamicsKiln):
+class CAREamicsModuleWrapper(CAREamicsModule):
    """Class defining the API for CAREamics Lightning layer.
 
    This class exposes parameters used to create an AlgorithmModel instance, triggering
@@ -287,6 +287,6 @@ class CAREamicsModule(CAREamicsKiln):
        algorithm_configuration["model"] = model_configuration
 
        # call the parent init using an AlgorithmModel instance
-        super().__init__(AlgorithmModel(**algorithm_configuration))
+        super().__init__(AlgorithmConfig(**algorithm_configuration))
 
        # TODO add load_from_checkpoint wrapper
careamics/lightning_prediction_datamodule.py
@@ -1,13 +1,14 @@
+"""Prediction Lightning data modules."""
+
 from pathlib import Path
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pytorch_lightning as L
-from albumentations import Compose
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 
-from careamics.config import InferenceModel
+from careamics.config import InferenceConfig
 from careamics.config.support import SupportedData
 from careamics.config.tile_information import TileInformation
 from careamics.dataset.dataset_utils import (
@@ -38,7 +39,7 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
 
    Parameters
    ----------
-    batch : Tuple[Tuple[np.ndarray, TileInformation], ...]
+    batch : List[Tuple[np.ndarray, TileInformation], ...]
        Batch of tiles.
 
    Returns
@@ -62,9 +63,9 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
    return default_collate(new_batch)
 
 
-class CAREamicsClay(L.LightningDataModule):
+class CAREamicsPredictData(L.LightningDataModule):
    """
-    LightningDataModule for prediction dataset.
+    CAREamics Lightning prediction data module.
 
    The data module can be used with Path, str or numpy arrays. The data can be either
    a folder containing images or a single file.
@@ -79,7 +80,7 @@ class CAREamicsClay(L.LightningDataModule):
 
    Parameters
    ----------
-    prediction_config : InferenceModel
+    pred_config : InferenceModel
        Pydantic model for CAREamics prediction configuration.
    pred_data : Union[Path, str, np.ndarray]
        Prediction data, can be a path to a folder, a file or a numpy array.
@@ -93,7 +94,7 @@ class CAREamicsClay(L.LightningDataModule):
 
    def __init__(
        self,
-        prediction_config: InferenceModel,
+        pred_config: InferenceConfig,
        pred_data: Union[Path, str, np.ndarray],
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
@@ -115,7 +116,7 @@ class CAREamicsClay(L.LightningDataModule):
 
        Parameters
        ----------
-        prediction_config : InferenceModel
+        pred_config : InferenceModel
            Pydantic model for CAREamics prediction configuration.
        pred_data : Union[Path, str, np.ndarray]
            Prediction data, can be a path to a folder, a file or a numpy array.
@@ -142,51 +143,53 @@ class CAREamicsClay(L.LightningDataModule):
        super().__init__()
 
        # check that a read source function is provided for custom types
-        if (
-            prediction_config.data_type == SupportedData.CUSTOM
-            and read_source_func is None
-        ):
+        if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
            raise ValueError(
                f"Data type {SupportedData.CUSTOM} is not allowed without "
-                f"specifying a `read_source_func`."
+                f"specifying a `read_source_func` and an `extension_filter`."
            )
 
-        # and that arrays are passed, if array type specified
-        elif prediction_config.data_type == SupportedData.ARRAY and not isinstance(
-            pred_data, np.ndarray
+        # check correct input type
+        if (
+            isinstance(pred_data, np.ndarray)
+            and pred_config.data_type != SupportedData.ARRAY
        ):
            raise ValueError(
-                f"Expected array input (see configuration.data.data_type), but got "
-                f"{type(pred_data)} instead."
+                f"Received a numpy array as input, but the data type was set to "
+                f"{pred_config.data_type}. Set the data type "
+                f"to {SupportedData.ARRAY} to predict on numpy arrays."
            )
 
        # and that Path or str are passed, if tiff file type specified
-        elif prediction_config.data_type == SupportedData.TIFF and not (
-            isinstance(pred_data, Path) or isinstance(pred_data, str)
+        elif (isinstance(pred_data, Path) or isinstance(pred_data, str)) and (
+            pred_config.data_type != SupportedData.TIFF
+            and pred_config.data_type != SupportedData.CUSTOM
        ):
            raise ValueError(
-                f"Expected Path or str input (see configuration.data.data_type), "
-                f"but got {type(pred_data)} instead."
+                f"Received a path as input, but the data type was neither set to "
+                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                f"to {SupportedData.TIFF} or "
+                f"{SupportedData.CUSTOM} to predict on files."
            )
 
        # configuration data
-        self.prediction_config = prediction_config
-        self.data_type = prediction_config.data_type
-        self.batch_size = prediction_config.batch_size
+        self.prediction_config = pred_config
+        self.data_type = pred_config.data_type
+        self.batch_size = pred_config.batch_size
        self.dataloader_params = dataloader_params
 
        self.pred_data = pred_data
-        self.tile_size = prediction_config.tile_size
-        self.tile_overlap = prediction_config.tile_overlap
+        self.tile_size = pred_config.tile_size
+        self.tile_overlap = pred_config.tile_overlap
 
        # read source function
-        if prediction_config.data_type == SupportedData.CUSTOM:
+        if pred_config.data_type == SupportedData.CUSTOM:
            # mypy check
            assert read_source_func is not None
 
            self.read_source_func: Callable = read_source_func
-        elif prediction_config.data_type != SupportedData.ARRAY:
-            self.read_source_func = get_read_func(prediction_config.data_type)
+        elif pred_config.data_type != SupportedData.ARRAY:
+            self.read_source_func = get_read_func(pred_config.data_type)
 
        self.extension_filter = extension_filter
 
@@ -238,9 +241,12 @@ class CAREamicsClay(L.LightningDataModule):
        )  # TODO check workers are used
 
 
-class CAREamicsPredictDataModule(CAREamicsClay):
+class PredictDataWrapper(CAREamicsPredictData):
    """
-    LightningDataModule wrapper of an inference dataset.
+    Wrapper around the CAREamics inference Lightning data module.
+
+    This class is used to explicitly pass the parameters usually contained in an
+    `inference_model` configuration.
 
    Since the lightning datamodule has no access to the model, make sure that the
    parameters passed to the datamodule are consistent with the model's requirements
@@ -251,14 +257,13 @@ class CAREamicsPredictDataModule(CAREamicsClay):
 
    The default transformations applied to the images are defined in
    `careamics.config.inference_model`. To use different transformations, pass a list
-    of transforms or an albumentation `Compose` as `transforms` parameter. See examples
+    of transforms. See examples
    for more details.
 
    The `mean` and `std` parameters are only used if Normalization is defined either
-    in the default transformations or in the `transforms` parameter, but not with
-    a `Compose` object. If you pass a `Normalization` transform in a list as
-    `transforms`, then the mean and std parameters will be overwritten by those passed
-    to this method.
+    in the default transformations or in the `transforms` parameter. If you pass a
+    `Normalization` transform in a list as `transforms`, then the mean and std
+    parameters will be overwritten by those passed to this method.
 
    By default, CAREamics only supports types defined in
    `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -270,6 +275,12 @@ class CAREamicsPredictDataModule(CAREamicsClay):
    dataloaders, except for `batch_size`, which is set by the `batch_size`
    parameter.
 
+    Note that if you are using a UNet model and tiling, the tile size must be
+    divisible in every dimension by 2**d, where d is the depth of the model. This
+    avoids artefacts arising from the broken shift invariance induced by the
+    pooling layers of the UNet. If your image has smaller dimensions, as may
+    happen in the Z dimension, consider padding your image.
+
    Parameters
    ----------
    pred_data : Union[str, Path, np.ndarray]
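
The tile-size constraint in the note above can be checked up front; `depth=2` is an assumed example value:

    # Each tile dimension must be divisible by 2**depth (UNet depth).
    depth = 2
    tile_size = (64, 64)
    assert all(dim % 2**depth == 0 for dim in tile_size), (
        "tile size must be divisible by 2**depth in every dimension"
    )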
@@ -292,7 +303,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
        Batch size.
    tta_transforms : bool, optional
        Use test time augmentation, by default True.
-    transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+    transforms : List, optional
        List of transforms to apply to prediction patches. If None, default
        transforms are applied.
    read_source_func : Optional[Callable], optional
@@ -315,7 +326,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
        axes: str = "YX",
        batch_size: int = 1,
        tta_transforms: bool = True,
-        transforms: Optional[Union[List, Compose]] = None,
+        transforms: Optional[List] = None,
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
        dataloader_params: Optional[dict] = None,
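
A hedged usage sketch of the renamed prediction wrapper; the parameter names come from the signature and docstring above, the values are illustrative, and the top-level import is an assumption:

    import numpy as np

    from careamics import PredictDataWrapper  # top-level export assumed

    image = np.random.rand(128, 128).astype(np.float32)
    predict_module = PredictDataWrapper(
        pred_data=image,
        data_type="array",
        tile_size=(64, 64),  # divisible by 2**depth, see the note above
        tile_overlap=(32, 32),
        axes="YX",
        batch_size=1,
    )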
@@ -329,6 +340,12 @@ class CAREamicsPredictDataModule(CAREamicsClay):
            Prediction data.
        data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
            Data type, see `SupportedData` for available options.
+        mean : float
+            Mean value for normalization, only used if Normalization is defined in the
+            transforms.
+        std : float
+            Standard deviation value for normalization, only used if Normalization is
+            defined in the transforms.
        tile_size : List[int]
            Tile size, 2D or 3D tile size.
        tile_overlap : List[int]
@@ -339,13 +356,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
            Batch size.
        tta_transforms : bool, optional
            Use test time augmentation, by default True.
-        mean : Optional[float], optional
-            Mean value for normalization, only used if Normalization is defined, by
-            default None.
-        std : Optional[float], optional
-            Standard deviation value for normalization, only used if Normalization is
-            defined, by default None.
-        transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+        transforms : Optional[List], optional
            List of transforms to apply to prediction patches. If None, default
            transforms are applied.
        read_source_func : Optional[Callable], optional
@@ -374,7 +385,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
            prediction_dict["transforms"] = transforms
 
        # validate configuration
-        self.prediction_config = InferenceModel(**prediction_dict)
+        self.prediction_config = InferenceConfig(**prediction_dict)
 
        # sanity check on the dataloader parameters
        if "batch_size" in dataloader_params:
@@ -382,7 +393,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
            del dataloader_params["batch_size"]
 
        super().__init__(
-            prediction_config=self.prediction_config,
+            pred_config=self.prediction_config,
            pred_data=pred_data,
            read_source_func=read_source_func,
            extension_filter=extension_filter,
careamics/losses/__init__.py
@@ -1,6 +1,5 @@
 """Losses module."""
 
-
 from .loss_factory import loss_factory
 
 # from .noise_model_factory import noise_model_factory as noise_model_factory
careamics/losses/loss_factory.py
@@ -3,6 +3,7 @@ Loss factory module.
 
 This module contains a factory function for creating loss functions.
 """
+
 from typing import Callable, Union
 
 from ..config.support import SupportedLoss