careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of careamics might be problematic.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -1,9 +1,10 @@
  """A class to train, predict and export models in CAREamics."""
 
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload
+ from typing import Any, Callable, Literal, Optional, Union, overload
 
  import numpy as np
+ from numpy.typing import NDArray
  from pytorch_lightning import Trainer
  from pytorch_lightning.callbacks import (
      Callback,
@@ -15,56 +16,54 @@ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
  from careamics.callbacks import ProgressBarCallback
  from careamics.config import (
      Configuration,
-     create_inference_configuration,
      load_configuration,
  )
  from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
  from careamics.dataset.dataset_utils import reshape_array
  from careamics.lightning_datamodule import CAREamicsTrainData
  from careamics.lightning_module import CAREamicsModule
- from careamics.lightning_prediction_datamodule import CAREamicsPredictData
- from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
  from careamics.model_io import export_to_bmz, load_pretrained
- from careamics.transforms import Denormalize
+ from careamics.prediction_utils import convert_outputs, create_pred_datamodule
  from careamics.utils import check_path_exists, get_logger
 
  from .callbacks import HyperParametersCallback
+ from .lightning_prediction_datamodule import CAREamicsPredictData
 
  logger = get_logger(__name__)
 
  LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
 
 
- # TODO napari callbacks
- # TODO: how to do AMP? How to continue training?
  class CAREamist:
      """Main CAREamics class, allowing training and prediction using various algorithms.
 
      Parameters
      ----------
-     source : Union[Path, str, Configuration]
+     source : pathlib.Path or str or CAREamics Configuration
          Path to a configuration file or a trained model.
-     work_dir : Optional[str], optional
+     work_dir : str, optional
          Path to working directory in which to save checkpoints and logs,
          by default None.
-     experiment_name : str, optional
-         Experiment name used for checkpoints, by default "CAREamics".
+     experiment_name : str, by default "CAREamics"
+         Experiment name used for checkpoints.
+     callbacks : list of Callback, optional
+         List of callbacks to use during training and prediction, by default None.
 
      Attributes
      ----------
-     model : CAREamicsKiln
+     model : CAREamicsModule
          CAREamics model.
      cfg : Configuration
          CAREamics configuration.
      trainer : Trainer
          PyTorch Lightning trainer.
-     experiment_logger : Optional[Union[TensorBoardLogger, WandbLogger]]
+     experiment_logger : TensorBoardLogger or WandbLogger
          Experiment logger, "wandb" or "tensorboard".
-     work_dir : Path
+     work_dir : pathlib.Path
          Working directory.
-     train_datamodule : Optional[CAREamicsWood]
+     train_datamodule : CAREamicsTrainData
          Training datamodule.
-     pred_datamodule : Optional[CAREamicsClay]
+     pred_datamodule : CAREamicsPredictData
          Prediction datamodule.
      """
 
@@ -74,6 +73,7 @@ class CAREamist:
          source: Union[Path, str],
          work_dir: Optional[str] = None,
          experiment_name: str = "CAREamics",
+         callbacks: Optional[list[Callback]] = None,
      ) -> None: ...
 
      @overload
@@ -82,6 +82,7 @@ class CAREamist:
          source: Configuration,
          work_dir: Optional[str] = None,
          experiment_name: str = "CAREamics",
+         callbacks: Optional[list[Callback]] = None,
      ) -> None: ...
 
      def __init__(
@@ -89,6 +90,7 @@ class CAREamist:
          source: Union[Path, str, Configuration],
          work_dir: Optional[Union[Path, str]] = None,
          experiment_name: str = "CAREamics",
+         callbacks: Optional[list[Callback]] = None,
      ) -> None:
          """
          Initialize CAREamist with a configuration object or a path.
@@ -105,13 +107,15 @@ class CAREamist:
 
          Parameters
          ----------
-         source : Union[Path, str, Configuration]
+         source : pathlib.Path or str or CAREamics Configuration
              Path to a configuration file or a trained model.
-         work_dir : Optional[str], optional
+         work_dir : str, optional
              Path to working directory in which to save checkpoints and logs,
              by default None.
          experiment_name : str, optional
              Experiment name used for checkpoints, by default "CAREamics".
+         callbacks : list of Callback, optional
+             List of callbacks to use during training and prediction, by default None.
 
          Raises
          ------
@@ -164,7 +168,7 @@ class CAREamist:
              self.model, self.cfg = load_pretrained(source)
 
          # define the checkpoint saving callback
-         self.callbacks = self._define_callbacks()
+         self._define_callbacks(callbacks)
 
          # instantiate logger
          if self.cfg.training_config.has_logger():
@@ -188,32 +192,50 @@ class CAREamist:
              logger=self.experiment_logger,
          )
 
-         # change the prediction loop, necessary for tiled prediction
-         self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
-
          # place holder for the datamodules
          self.train_datamodule: Optional[CAREamicsTrainData] = None
          self.pred_datamodule: Optional[CAREamicsPredictData] = None
 
-     def _define_callbacks(self) -> List[Callback]:
+     def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
          """
          Define the callbacks for the training loop.
 
-         Returns
-         -------
-         List[Callback]
-             List of callbacks to be used during training.
+         Parameters
+         ----------
+         callbacks : list of Callback, optional
+             List of callbacks to use during training and prediction, by default None.
          """
+         self.callbacks = [] if callbacks is None else callbacks
+
+         # check that user callbacks are not any of the CAREamics callbacks
+         for c in self.callbacks:
+             if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
+                 raise ValueError(
+                     "ModelCheckpoint and EarlyStopping callbacks are already defined "
+                     "in CAREamics and should only be modified through the "
+                     "training configuration (see TrainingConfig)."
+                 )
+
+             if isinstance(c, HyperParametersCallback) or isinstance(
+                 c, ProgressBarCallback
+             ):
+                 raise ValueError(
+                     "HyperParameter and ProgressBar callbacks are defined internally "
+                     "and should not be passed as callbacks."
+                 )
+
          # checkpoint callback saves checkpoints during training
-         self.callbacks = [
-             HyperParametersCallback(self.cfg),
-             ModelCheckpoint(
-                 dirpath=self.work_dir / Path("checkpoints"),
-                 filename=self.cfg.experiment_name,
-                 **self.cfg.training_config.checkpoint_callback.model_dump(),
-             ),
-             ProgressBarCallback(),
-         ]
+         self.callbacks.extend(
+             [
+                 HyperParametersCallback(self.cfg),
+                 ModelCheckpoint(
+                     dirpath=self.work_dir / Path("checkpoints"),
+                     filename=self.cfg.experiment_name,
+                     **self.cfg.training_config.checkpoint_callback.model_dump(),
+                 ),
+                 ProgressBarCallback(),
+             ]
+         )
 
          # early stopping callback
          if self.cfg.training_config.early_stopping_callback is not None:
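The new `callbacks` argument lets users attach their own Lightning callbacks at construction time, while the guard above rejects the callback types that CAREamics manages itself. A minimal sketch of the intended usage, assuming a configuration file "config.yml" (the file name and the LearningRateMonitor choice are illustrative, not part of this diff):

    from pytorch_lightning.callbacks import LearningRateMonitor

    from careamics import CAREamist

    # user callbacks are accepted, except the CAREamics-managed types:
    # passing ModelCheckpoint or EarlyStopping here raises ValueError
    careamist = CAREamist(
        source="config.yml",
        work_dir="runs",
        callbacks=[LearningRateMonitor(logging_interval="epoch")],
    )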
@@ -221,16 +243,14 @@ class CAREamist:
                  EarlyStopping(self.cfg.training_config.early_stopping_callback)
              )
 
-         return self.callbacks
-
      def train(
          self,
          *,
          datamodule: Optional[CAREamicsTrainData] = None,
-         train_source: Optional[Union[Path, str, np.ndarray]] = None,
-         val_source: Optional[Union[Path, str, np.ndarray]] = None,
-         train_target: Optional[Union[Path, str, np.ndarray]] = None,
-         val_target: Optional[Union[Path, str, np.ndarray]] = None,
+         train_source: Optional[Union[Path, str, NDArray]] = None,
+         val_source: Optional[Union[Path, str, NDArray]] = None,
+         train_target: Optional[Union[Path, str, NDArray]] = None,
+         val_target: Optional[Union[Path, str, NDArray]] = None,
          use_in_memory: bool = True,
          val_percentage: float = 0.1,
          val_minimum_split: int = 1,
@@ -253,15 +273,15 @@ class CAREamist:
 
          Parameters
          ----------
-         datamodule : Optional[CAREamicsWood], optional
+         datamodule : CAREamicsTrainData, optional
              Datamodule to train on, by default None.
-         train_source : Optional[Union[Path, str, np.ndarray]], optional
+         train_source : pathlib.Path or str or NDArray, optional
              Train source, if no datamodule is provided, by default None.
-         val_source : Optional[Union[Path, str, np.ndarray]], optional
+         val_source : pathlib.Path or str or NDArray, optional
              Validation source, if no datamodule is provided, by default None.
-         train_target : Optional[Union[Path, str, np.ndarray]], optional
+         train_target : pathlib.Path or str or NDArray, optional
              Train target source, if no datamodule is provided, by default None.
-         val_target : Optional[Union[Path, str, np.ndarray]], optional
+         val_target : pathlib.Path or str or NDArray, optional
              Validation target source, if no datamodule is provided, by default None.
          use_in_memory : bool, optional
              Use in memory dataset if possible, by default True.
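For reference, a minimal training call against the updated signature, reusing the `careamist` instance from the sketch above (the array shape is illustrative):

    import numpy as np

    # train on an in-memory 2D array; patches are split into train and
    # validation sets (10% by default, at least `val_minimum_split` patches)
    train_image = np.random.rand(512, 512).astype(np.float32)
    careamist.train(train_source=train_image, val_percentage=0.1)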
@@ -355,7 +375,7 @@ class CAREamist:
 
          else:
              raise ValueError(
-                 f"Invalid input, expected a str, Path, array or CAREamicsWood "
+                 f"Invalid input, expected a str, Path, array or CAREamicsTrainData "
                  f"instance (got {type(train_source)})."
              )
 
@@ -365,7 +385,7 @@ class CAREamist:
 
          Parameters
          ----------
-         datamodule : CAREamicsWood
+         datamodule : CAREamicsTrainData
              Datamodule to train on.
          """
          # record datamodule
@@ -375,10 +395,10 @@ class CAREamist:
 
      def _train_on_array(
          self,
-         train_data: np.ndarray,
-         val_data: Optional[np.ndarray] = None,
-         train_target: Optional[np.ndarray] = None,
-         val_target: Optional[np.ndarray] = None,
+         train_data: NDArray,
+         val_data: Optional[NDArray] = None,
+         train_target: Optional[NDArray] = None,
+         val_target: Optional[NDArray] = None,
          val_percentage: float = 0.1,
          val_minimum_split: int = 5,
      ) -> None:
@@ -387,13 +407,13 @@ class CAREamist:
 
          Parameters
          ----------
-         train_data : np.ndarray
+         train_data : NDArray
              Training data.
-         val_data : Optional[np.ndarray], optional
+         val_data : NDArray, optional
              Validation data, by default None.
-         train_target : Optional[np.ndarray], optional
+         train_target : NDArray, optional
              Train target data, by default None.
-         val_target : Optional[np.ndarray], optional
+         val_target : NDArray, optional
              Validation target data, by default None.
          val_percentage : float, optional
              Percentage of patches to use for validation, by default 0.1.
@@ -429,13 +449,13 @@ class CAREamist:
 
          Parameters
          ----------
-         path_to_train_data : Union[Path, str]
+         path_to_train_data : pathlib.Path or str
              Path to the training data.
-         path_to_val_data : Optional[Union[Path, str]], optional
+         path_to_val_data : pathlib.Path or str, optional
              Path to validation data, by default None.
-         path_to_train_target : Optional[Union[Path, str]], optional
+         path_to_train_target : pathlib.Path or str, optional
              Path to train target data, by default None.
-         path_to_val_target : Optional[Union[Path, str]], optional
+         path_to_val_target : pathlib.Path or str, optional
              Path to validation target data, by default None.
          use_in_memory : bool, optional
              Use in memory dataset if possible, by default True.
@@ -477,7 +497,7 @@ class CAREamist:
          source: CAREamicsPredictData,
          *,
          checkpoint: Optional[Literal["best", "last"]] = None,
-     ) -> Union[list, np.ndarray]: ...
+     ) -> Union[list[NDArray], NDArray]: ...
 
      @overload
      def predict(  # numpydoc ignore=GL08
@@ -485,59 +505,62 @@ class CAREamist:
          source: Union[Path, str],
          *,
          batch_size: int = 1,
-         tile_size: Optional[Tuple[int, ...]] = None,
-         tile_overlap: Tuple[int, ...] = (48, 48),
+         tile_size: Optional[tuple[int, ...]] = None,
+         tile_overlap: tuple[int, ...] = (48, 48),
          axes: Optional[str] = None,
          data_type: Optional[Literal["tiff", "custom"]] = None,
          tta_transforms: bool = True,
-         dataloader_params: Optional[Dict] = None,
+         dataloader_params: Optional[dict] = None,
          read_source_func: Optional[Callable] = None,
          extension_filter: str = "",
          checkpoint: Optional[Literal["best", "last"]] = None,
-     ) -> Union[list, np.ndarray]: ...
+     ) -> Union[list[NDArray], NDArray]: ...
 
      @overload
      def predict(  # numpydoc ignore=GL08
          self,
-         source: np.ndarray,
+         source: NDArray,
          *,
          batch_size: int = 1,
-         tile_size: Optional[Tuple[int, ...]] = None,
-         tile_overlap: Tuple[int, ...] = (48, 48),
+         tile_size: Optional[tuple[int, ...]] = None,
+         tile_overlap: tuple[int, ...] = (48, 48),
          axes: Optional[str] = None,
          data_type: Optional[Literal["array"]] = None,
          tta_transforms: bool = True,
-         dataloader_params: Optional[Dict] = None,
+         dataloader_params: Optional[dict] = None,
          checkpoint: Optional[Literal["best", "last"]] = None,
-     ) -> Union[list, np.ndarray]: ...
+     ) -> Union[list[NDArray], NDArray]: ...
 
      def predict(
          self,
-         source: Union[CAREamicsPredictData, Path, str, np.ndarray],
+         source: Union[CAREamicsPredictData, Path, str, NDArray],
          *,
-         batch_size: int = 1,
-         tile_size: Optional[Tuple[int, ...]] = None,
-         tile_overlap: Tuple[int, ...] = (48, 48),
+         batch_size: Optional[int] = None,
+         tile_size: Optional[tuple[int, ...]] = None,
+         tile_overlap: tuple[int, ...] = (48, 48),
          axes: Optional[str] = None,
          data_type: Optional[Literal["array", "tiff", "custom"]] = None,
          tta_transforms: bool = True,
-         dataloader_params: Optional[Dict] = None,
+         dataloader_params: Optional[dict] = None,
          read_source_func: Optional[Callable] = None,
          extension_filter: str = "",
          checkpoint: Optional[Literal["best", "last"]] = None,
          **kwargs: Any,
-     ) -> Union[List[np.ndarray], np.ndarray]:
+     ) -> Union[list[NDArray], NDArray]:
          """
          Make predictions on the provided data.
 
-         Input can be a CAREamicsClay instance, a path to a data file, or a numpy array.
+         Input can be a CAREamicsPredData instance, a path to a data file, or a numpy
+         array.
 
          If `data_type`, `axes` and `tile_size` are not provided, the training
          configuration parameters will be used, with the `patch_size` instead of
          `tile_size`.
 
          Test-time augmentation (TTA) can be switched off using the `tta_transforms`
-         parameter.
+         parameter. The TTA augmentation applies all possible flip and 90 degrees
+         rotations to the prediction input and averages the predictions. TTA augmentation
+         should not be used if you did not train with these augmentations.
 
          Note that if you are using a UNet model and tiling, the tile size must be
          divisible in every dimension by 2**d, where d is the depth of the model. This
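As a concrete reading of the divisibility rule carried over in this docstring: for a UNet of depth 2, every tile dimension must be a multiple of 2**2 = 4, so a tile size of (256, 256) is valid while (250, 250) is not. A sketch of the check:

    # mirrors the 2**d divisibility rule described in the docstring above
    depth = 2
    tile_size = (256, 256)
    assert all(t % 2**depth == 0 for t in tile_size)  # holds for (256, 256)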
@@ -547,221 +570,96 @@ class CAREamist:
 
          Parameters
          ----------
-         source : Union[CAREamicsClay, Path, str, np.ndarray]
+         source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
              Data to predict on.
-         batch_size : int, optional
-             Batch size for prediction, by default 1.
-         tile_size : Optional[Tuple[int, ...]], optional
-             Size of the tiles to use for prediction, by default None.
-         tile_overlap : Tuple[int, ...], optional
-             Overlap between tiles, by default (48, 48).
-         axes : Optional[str], optional
+         batch_size : int, default=1
+             Batch size for prediction.
+         tile_size : tuple of int, optional
+             Size of the tiles to use for prediction.
+         tile_overlap : tuple of int, default=(48, 48)
+             Overlap between tiles.
+         axes : str, optional
              Axes of the input data, by default None.
-         data_type : Optional[Literal["array", "tiff", "custom"]], optional
-             Type of the input data, by default None.
-         tta_transforms : bool, optional
-             Whether to apply test-time augmentation, by default True.
-         dataloader_params : Optional[Dict], optional
-             Parameters to pass to the dataloader, by default None.
-         read_source_func : Optional[Callable], optional
-             Function to read the source data, by default None.
-         extension_filter : str, optional
-             Filter for the file extension, by default "".
-         checkpoint : Optional[Literal["best", "last"]], optional
-             Checkpoint to use for prediction, by default None.
+         data_type : {"array", "tiff", "custom"}, optional
+             Type of the input data.
+         tta_transforms : bool, default=True
+             Whether to apply test-time augmentation.
+         dataloader_params : dict, optional
+             Parameters to pass to the dataloader.
+         read_source_func : Callable, optional
+             Function to read the source data.
+         extension_filter : str, default=""
+             Filter for the file extension.
+         checkpoint : {"best", "last"}, optional
+             Checkpoint to use for prediction.
          **kwargs : Any
              Unused.
 
          Returns
          -------
-         Union[List[np.ndarray], np.ndarray]
+         list of NDArray or NDArray
              Predictions made by the model.
-
-         Raises
-         ------
-         ValueError
-             If the input is not a CAREamicsClay instance, a path or a numpy array.
          """
-         if isinstance(source, CAREamicsPredictData):
-             # record datamodule
-             self.pred_datamodule = source
-
-             return self.trainer.predict(
-                 model=self.model, datamodule=source, ckpt_path=checkpoint
+         # Reuse batch size if not provided explicitly
+         if batch_size is None:
+             batch_size = (
+                 self.train_datamodule.batch_size
+                 if self.train_datamodule
+                 else self.cfg.data_config.batch_size
              )
-         else:
-             if self.cfg is None:
-                 raise ValueError(
-                     "No configuration found. Train a model or load from a "
-                     "checkpoint before predicting."
-                 )
-             # create predict config, reuse training config if parameters missing
-             prediction_config = create_inference_configuration(
-                 configuration=self.cfg,
-                 tile_size=tile_size,
-                 tile_overlap=tile_overlap,
-                 data_type=data_type,
-                 axes=axes,
-                 tta_transforms=tta_transforms,
-                 batch_size=batch_size,
-             )
-
-             # remove batch from dataloader parameters (priority given to config)
-             if dataloader_params is None:
-                 dataloader_params = {}
-             if "batch_size" in dataloader_params:
-                 del dataloader_params["batch_size"]
-
-             if isinstance(source, Path) or isinstance(source, str):
-                 # Check the source
-                 source_path = check_path_exists(source)
-
-                 # create datamodule
-                 datamodule = CAREamicsPredictData(
-                     pred_config=prediction_config,
-                     pred_data=source_path,
-                     read_source_func=read_source_func,
-                     extension_filter=extension_filter,
-                     dataloader_params=dataloader_params,
-                 )
-
-                 # record datamodule
-                 self.pred_datamodule = datamodule
-
-                 return self.trainer.predict(
-                     model=self.model, datamodule=datamodule, ckpt_path=checkpoint
-                 )
-
-             elif isinstance(source, np.ndarray):
-                 # create datamodule
-                 datamodule = CAREamicsPredictData(
-                     pred_config=prediction_config,
-                     pred_data=source,
-                     dataloader_params=dataloader_params,
-                 )
-
-                 # record datamodule
-                 self.pred_datamodule = datamodule
-
-                 return self.trainer.predict(
-                     model=self.model, datamodule=datamodule, ckpt_path=checkpoint
-                 )
-
-             else:
-                 raise ValueError(
-                     f"Invalid input. Expected a CAREamicsWood instance, paths or "
-                     f"np.ndarray (got {type(source)})."
-                 )
-
-     def _create_data_for_bmz(
-         self,
-         input_array: Optional[np.ndarray] = None,
-     ) -> np.ndarray:
-         """Create data for BMZ export.
-
-         If no `input_array` is provided, this method checks if there is a prediction
-         datamodule, or a training data module, to extract a patch. If none exists,
-         then a random aray is created.
-
-         If there is a non-singleton batch dimension, this method returns only the first
-         element.
-
-         Parameters
-         ----------
-         input_array : Optional[np.ndarray], optional
-             Input array, by default None.
-
-         Returns
-         -------
-         np.ndarray
-             Input data for BMZ export.
-
-         Raises
-         ------
-         ValueError
-             If mean and std are not provided in the configuration.
-         """
-         if input_array is None:
-             if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
-                 raise ValueError(
-                     "Mean and std cannot be None in the configuration in order to"
-                     "export to the BMZ format. Was the model trained?"
-                 )
-
-             # generate images, priority is given to the prediction data module
-             if self.pred_datamodule is not None:
-                 # unpack a batch, ignore masks or targets
-                 input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))
-
-                 # convert torch.Tensor to numpy
-                 input_patch = input_patch.numpy()
-
-                 # denormalize
-                 denormalize = Denormalize(
-                     mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
-                 )
-                 input_patch, _ = denormalize(input_patch)
-
-             elif self.train_datamodule is not None:
-                 input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
-                 input_patch = input_patch.numpy()
-
-                 # denormalize
-                 denormalize = Denormalize(
-                     mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
-                 )
-                 input_patch, _ = denormalize(input_patch)
-             else:
-                 # create a random input array
-                 input_patch = np.random.normal(
-                     loc=self.cfg.data_config.mean,
-                     scale=self.cfg.data_config.std,
-                     size=self.cfg.data_config.patch_size,
-                 ).astype(np.float32)[
-                     np.newaxis, np.newaxis, ...
-                 ]  # add S & C dimensions
-         else:
-             # potentially correct shape
-             input_patch = reshape_array(input_array, self.cfg.data_config.axes)
 
-         # if this a batch
-         if input_patch.shape[0] > 1:
-             input_patch = input_patch[[0], ...]  # keep singleton dim
+         self.pred_datamodule = create_pred_datamodule(
+             source=source,
+             config=self.cfg,
+             batch_size=batch_size,
+             tile_size=tile_size,
+             tile_overlap=tile_overlap,
+             axes=axes,
+             data_type=data_type,
+             tta_transforms=tta_transforms,
+             dataloader_params=dataloader_params,
+             read_source_func=read_source_func,
+             extension_filter=extension_filter,
+         )
 
-         return input_patch
+         predictions = self.trainer.predict(
+             model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint
+         )
+         return convert_outputs(predictions, self.pred_datamodule.tiled)
 
      def export_to_bmz(
          self,
          path: Union[Path, str],
          name: str,
-         authors: List[dict],
-         input_array: Optional[np.ndarray] = None,
+         input_array: NDArray,
+         authors: list[dict],
          general_description: str = "",
-         channel_names: Optional[List[str]] = None,
+         channel_names: Optional[list[str]] = None,
          data_description: Optional[str] = None,
      ) -> None:
          """Export the model to the BioImage Model Zoo format.
 
-         Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
+         Input array must be of the same dimensions as the axes recorded in the
+         configuration of the `CAREamist`.
 
          Parameters
          ----------
-         path : Union[Path, str]
+         path : pathlib.Path or str
              Path to save the model.
          name : str
              Name of the model.
-         authors : List[dict]
+         input_array : NDArray
+             Input array used to validate the model and as example.
+         authors : list of dict
              List of authors of the model.
-         input_array : Optional[np.ndarray], optional
-             Input array for the model, must be of shape SC(Z)YX, by default None.
          general_description : str
              General description of the model, used in the metadata of the BMZ archive.
-         channel_names : Optional[List[str]], optional
+         channel_names : list of str, optional
              Channel names, by default None.
-         data_description : Optional[str], optional
+         data_description : str, optional
              Description of the data, by default None.
          """
-         input_patch = self._create_data_for_bmz(input_array)
+         input_patch = reshape_array(input_array, self.cfg.data_config.axes)
 
          # axes need to be reformated for the export because reshaping was done in the
          # datamodule
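The prediction path is now delegated to `create_pred_datamodule` and `convert_outputs`, so tiled prediction reduces to a single call. A sketch, assuming a trained `careamist` and an existing folder of TIFF files (the path is illustrative):

    # tiled prediction on files; tiles are stitched back into full images
    # and returned as a list of arrays, one per input image
    predictions = careamist.predict(
        source="data/test",
        data_type="tiff",
        tile_size=(256, 256),
        tile_overlap=(48, 48),
        tta_transforms=False,
    )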
@@ -778,11 +676,10 @@ class CAREamist:
              tta_transforms=False,
          )
 
-         if not isinstance(output_patch, np.ndarray):
-             raise ValueError(
-                 f"Numpy array required for export to BioImage Model Zoo, got "
-                 f"{type(output_patch)}."
-             )
+         if isinstance(output_patch, list):
+             output = np.concatenate(output_patch, axis=0)
+         else:
+             output = output_patch
 
          export_to_bmz(
              model=self.model,
@@ -792,7 +689,7 @@ class CAREamist:
              general_description=general_description,
              authors=authors,
              input_array=input_patch,
-             output_array=output_patch,
+             output_array=output,
              channel_names=channel_names,
              data_description=data_description,
          )
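With `input_array` now a required parameter and the `_create_data_for_bmz` fallback removed, BMZ export must be given an example image explicitly, matching the axes recorded in the configuration. A sketch with illustrative values:

    import numpy as np

    # the example array is reshaped internally with reshape_array according
    # to the configured axes, then run through predict() to produce the
    # output array stored in the BMZ archive
    careamist.export_to_bmz(
        path="exports/model.zip",
        name="my_model",
        input_array=np.random.rand(256, 256).astype(np.float32),
        authors=[{"name": "Jane Doe"}],
        general_description="Denoising model trained with CAREamics.",
    )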