careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -1,9 +1,10 @@
1
1
  """A class to train, predict and export models in CAREamics."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload
4
+ from typing import Any, Callable, Literal, Optional, Union, overload
5
5
 
6
6
  import numpy as np
7
+ from numpy.typing import NDArray
7
8
  from pytorch_lightning import Trainer
8
9
  from pytorch_lightning.callbacks import (
9
10
  Callback,
@@ -15,55 +16,54 @@ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
15
16
  from careamics.callbacks import ProgressBarCallback
16
17
  from careamics.config import (
17
18
  Configuration,
18
- create_inference_configuration,
19
19
  load_configuration,
20
20
  )
21
- from careamics.config.inference_model import TRANSFORMS_UNION
22
21
  from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
22
+ from careamics.dataset.dataset_utils import reshape_array
23
23
  from careamics.lightning_datamodule import CAREamicsTrainData
24
24
  from careamics.lightning_module import CAREamicsModule
25
- from careamics.lightning_prediction_datamodule import CAREamicsPredictData
26
- from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
27
25
  from careamics.model_io import export_to_bmz, load_pretrained
26
+ from careamics.prediction_utils import convert_outputs, create_pred_datamodule
28
27
  from careamics.utils import check_path_exists, get_logger
29
28
 
30
29
  from .callbacks import HyperParametersCallback
30
+ from .lightning_prediction_datamodule import CAREamicsPredictData
31
31
 
32
32
  logger = get_logger(__name__)
33
33
 
34
34
  LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
35
35
 
36
36
 
37
- # TODO napari callbacks
38
- # TODO: how to do AMP? How to continue training?
39
37
  class CAREamist:
40
38
  """Main CAREamics class, allowing training and prediction using various algorithms.
41
39
 
42
40
  Parameters
43
41
  ----------
44
- source : Union[Path, str, Configuration]
42
+ source : pathlib.Path or str or CAREamics Configuration
45
43
  Path to a configuration file or a trained model.
46
- work_dir : Optional[str], optional
44
+ work_dir : str, optional
47
45
  Path to working directory in which to save checkpoints and logs,
48
46
  by default None.
49
- experiment_name : str, optional
50
- Experiment name used for checkpoints, by default "CAREamics".
47
+ experiment_name : str, by default "CAREamics"
48
+ Experiment name used for checkpoints.
49
+ callbacks : list of Callback, optional
50
+ List of callbacks to use during training and prediction, by default None.
51
51
 
52
52
  Attributes
53
53
  ----------
54
- model : CAREamicsKiln
54
+ model : CAREamicsModule
55
55
  CAREamics model.
56
56
  cfg : Configuration
57
57
  CAREamics configuration.
58
58
  trainer : Trainer
59
59
  PyTorch Lightning trainer.
60
- experiment_logger : Optional[Union[TensorBoardLogger, WandbLogger]]
60
+ experiment_logger : TensorBoardLogger or WandbLogger
61
61
  Experiment logger, "wandb" or "tensorboard".
62
- work_dir : Path
62
+ work_dir : pathlib.Path
63
63
  Working directory.
64
- train_datamodule : Optional[CAREamicsWood]
64
+ train_datamodule : CAREamicsTrainData
65
65
  Training datamodule.
66
- pred_datamodule : Optional[CAREamicsClay]
66
+ pred_datamodule : CAREamicsPredictData
67
67
  Prediction datamodule.
68
68
  """
69
69
 
@@ -73,6 +73,7 @@ class CAREamist:
73
73
  source: Union[Path, str],
74
74
  work_dir: Optional[str] = None,
75
75
  experiment_name: str = "CAREamics",
76
+ callbacks: Optional[list[Callback]] = None,
76
77
  ) -> None: ...
77
78
 
78
79
  @overload
@@ -81,6 +82,7 @@ class CAREamist:
81
82
  source: Configuration,
82
83
  work_dir: Optional[str] = None,
83
84
  experiment_name: str = "CAREamics",
85
+ callbacks: Optional[list[Callback]] = None,
84
86
  ) -> None: ...
85
87
 
86
88
  def __init__(
@@ -88,6 +90,7 @@ class CAREamist:
88
90
  source: Union[Path, str, Configuration],
89
91
  work_dir: Optional[Union[Path, str]] = None,
90
92
  experiment_name: str = "CAREamics",
93
+ callbacks: Optional[list[Callback]] = None,
91
94
  ) -> None:
92
95
  """
93
96
  Initialize CAREamist with a configuration object or a path.
@@ -104,13 +107,15 @@ class CAREamist:
104
107
 
105
108
  Parameters
106
109
  ----------
107
- source : Union[Path, str, Configuration]
110
+ source : pathlib.Path or str or CAREamics Configuration
108
111
  Path to a configuration file or a trained model.
109
- work_dir : Optional[str], optional
112
+ work_dir : str, optional
110
113
  Path to working directory in which to save checkpoints and logs,
111
114
  by default None.
112
115
  experiment_name : str, optional
113
116
  Experiment name used for checkpoints, by default "CAREamics".
117
+ callbacks : list of Callback, optional
118
+ List of callbacks to use during training and prediction, by default None.
114
119
 
115
120
  Raises
116
121
  ------
@@ -163,7 +168,7 @@ class CAREamist:
163
168
  self.model, self.cfg = load_pretrained(source)
164
169
 
165
170
  # define the checkpoint saving callback
166
- self.callbacks = self._define_callbacks()
171
+ self._define_callbacks(callbacks)
167
172
 
168
173
  # instantiate logger
169
174
  if self.cfg.training_config.has_logger():
@@ -187,32 +192,50 @@ class CAREamist:
187
192
  logger=self.experiment_logger,
188
193
  )
189
194
 
190
- # change the prediction loop, necessary for tiled prediction
191
- self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
192
-
193
195
  # place holder for the datamodules
194
196
  self.train_datamodule: Optional[CAREamicsTrainData] = None
195
197
  self.pred_datamodule: Optional[CAREamicsPredictData] = None
196
198
 
197
- def _define_callbacks(self) -> List[Callback]:
199
+ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
198
200
  """
199
201
  Define the callbacks for the training loop.
200
202
 
201
- Returns
202
- -------
203
- List[Callback]
204
- List of callbacks to be used during training.
203
+ Parameters
204
+ ----------
205
+ callbacks : list of Callback, optional
206
+ List of callbacks to use during training and prediction, by default None.
205
207
  """
208
+ self.callbacks = [] if callbacks is None else callbacks
209
+
210
+ # check that user callbacks are not any of the CAREamics callbacks
211
+ for c in self.callbacks:
212
+ if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
213
+ raise ValueError(
214
+ "ModelCheckpoint and EarlyStopping callbacks are already defined "
215
+ "in CAREamics and should only be modified through the "
216
+ "training configuration (see TrainingConfig)."
217
+ )
218
+
219
+ if isinstance(c, HyperParametersCallback) or isinstance(
220
+ c, ProgressBarCallback
221
+ ):
222
+ raise ValueError(
223
+ "HyperParameter and ProgressBar callbacks are defined internally "
224
+ "and should not be passed as callbacks."
225
+ )
226
+
206
227
  # checkpoint callback saves checkpoints during training
207
- self.callbacks = [
208
- HyperParametersCallback(self.cfg),
209
- ModelCheckpoint(
210
- dirpath=self.work_dir / Path("checkpoints"),
211
- filename=self.cfg.experiment_name,
212
- **self.cfg.training_config.checkpoint_callback.model_dump(),
213
- ),
214
- ProgressBarCallback(),
215
- ]
228
+ self.callbacks.extend(
229
+ [
230
+ HyperParametersCallback(self.cfg),
231
+ ModelCheckpoint(
232
+ dirpath=self.work_dir / Path("checkpoints"),
233
+ filename=self.cfg.experiment_name,
234
+ **self.cfg.training_config.checkpoint_callback.model_dump(),
235
+ ),
236
+ ProgressBarCallback(),
237
+ ]
238
+ )
216
239
 
217
240
  # early stopping callback
218
241
  if self.cfg.training_config.early_stopping_callback is not None:
@@ -220,16 +243,14 @@ class CAREamist:
220
243
  EarlyStopping(self.cfg.training_config.early_stopping_callback)
221
244
  )
222
245
 
223
- return self.callbacks
224
-
225
246
  def train(
226
247
  self,
227
248
  *,
228
249
  datamodule: Optional[CAREamicsTrainData] = None,
229
- train_source: Optional[Union[Path, str, np.ndarray]] = None,
230
- val_source: Optional[Union[Path, str, np.ndarray]] = None,
231
- train_target: Optional[Union[Path, str, np.ndarray]] = None,
232
- val_target: Optional[Union[Path, str, np.ndarray]] = None,
250
+ train_source: Optional[Union[Path, str, NDArray]] = None,
251
+ val_source: Optional[Union[Path, str, NDArray]] = None,
252
+ train_target: Optional[Union[Path, str, NDArray]] = None,
253
+ val_target: Optional[Union[Path, str, NDArray]] = None,
233
254
  use_in_memory: bool = True,
234
255
  val_percentage: float = 0.1,
235
256
  val_minimum_split: int = 1,
@@ -252,15 +273,15 @@ class CAREamist:
252
273
 
253
274
  Parameters
254
275
  ----------
255
- datamodule : Optional[CAREamicsWood], optional
276
+ datamodule : CAREamicsTrainData, optional
256
277
  Datamodule to train on, by default None.
257
- train_source : Optional[Union[Path, str, np.ndarray]], optional
278
+ train_source : pathlib.Path or str or NDArray, optional
258
279
  Train source, if no datamodule is provided, by default None.
259
- val_source : Optional[Union[Path, str, np.ndarray]], optional
280
+ val_source : pathlib.Path or str or NDArray, optional
260
281
  Validation source, if no datamodule is provided, by default None.
261
- train_target : Optional[Union[Path, str, np.ndarray]], optional
282
+ train_target : pathlib.Path or str or NDArray, optional
262
283
  Train target source, if no datamodule is provided, by default None.
263
- val_target : Optional[Union[Path, str, np.ndarray]], optional
284
+ val_target : pathlib.Path or str or NDArray, optional
264
285
  Validation target source, if no datamodule is provided, by default None.
265
286
  use_in_memory : bool, optional
266
287
  Use in memory dataset if possible, by default True.
@@ -354,7 +375,7 @@ class CAREamist:
354
375
 
355
376
  else:
356
377
  raise ValueError(
357
- f"Invalid input, expected a str, Path, array or CAREamicsWood "
378
+ f"Invalid input, expected a str, Path, array or CAREamicsTrainData "
358
379
  f"instance (got {type(train_source)})."
359
380
  )
360
381
 
@@ -364,7 +385,7 @@ class CAREamist:
364
385
 
365
386
  Parameters
366
387
  ----------
367
- datamodule : CAREamicsWood
388
+ datamodule : CAREamicsTrainData
368
389
  Datamodule to train on.
369
390
  """
370
391
  # record datamodule
@@ -374,10 +395,10 @@ class CAREamist:
374
395
 
375
396
  def _train_on_array(
376
397
  self,
377
- train_data: np.ndarray,
378
- val_data: Optional[np.ndarray] = None,
379
- train_target: Optional[np.ndarray] = None,
380
- val_target: Optional[np.ndarray] = None,
398
+ train_data: NDArray,
399
+ val_data: Optional[NDArray] = None,
400
+ train_target: Optional[NDArray] = None,
401
+ val_target: Optional[NDArray] = None,
381
402
  val_percentage: float = 0.1,
382
403
  val_minimum_split: int = 5,
383
404
  ) -> None:
@@ -386,13 +407,13 @@ class CAREamist:
386
407
 
387
408
  Parameters
388
409
  ----------
389
- train_data : np.ndarray
410
+ train_data : NDArray
390
411
  Training data.
391
- val_data : Optional[np.ndarray], optional
412
+ val_data : NDArray, optional
392
413
  Validation data, by default None.
393
- train_target : Optional[np.ndarray], optional
414
+ train_target : NDArray, optional
394
415
  Train target data, by default None.
395
- val_target : Optional[np.ndarray], optional
416
+ val_target : NDArray, optional
396
417
  Validation target data, by default None.
397
418
  val_percentage : float, optional
398
419
  Percentage of patches to use for validation, by default 0.1.
@@ -428,13 +449,13 @@ class CAREamist:
428
449
 
429
450
  Parameters
430
451
  ----------
431
- path_to_train_data : Union[Path, str]
452
+ path_to_train_data : pathlib.Path or str
432
453
  Path to the training data.
433
- path_to_val_data : Optional[Union[Path, str]], optional
454
+ path_to_val_data : pathlib.Path or str, optional
434
455
  Path to validation data, by default None.
435
- path_to_train_target : Optional[Union[Path, str]], optional
456
+ path_to_train_target : pathlib.Path or str, optional
436
457
  Path to train target data, by default None.
437
- path_to_val_target : Optional[Union[Path, str]], optional
458
+ path_to_val_target : pathlib.Path or str, optional
438
459
  Path to validation target data, by default None.
439
460
  use_in_memory : bool, optional
440
461
  Use in memory dataset if possible, by default True.
@@ -476,7 +497,7 @@ class CAREamist:
476
497
  source: CAREamicsPredictData,
477
498
  *,
478
499
  checkpoint: Optional[Literal["best", "last"]] = None,
479
- ) -> Union[list, np.ndarray]: ...
500
+ ) -> Union[list[NDArray], NDArray]: ...
480
501
 
481
502
  @overload
482
503
  def predict( # numpydoc ignore=GL08
@@ -484,64 +505,62 @@ class CAREamist:
484
505
  source: Union[Path, str],
485
506
  *,
486
507
  batch_size: int = 1,
487
- tile_size: Optional[Tuple[int, ...]] = None,
488
- tile_overlap: Tuple[int, ...] = (48, 48),
508
+ tile_size: Optional[tuple[int, ...]] = None,
509
+ tile_overlap: tuple[int, ...] = (48, 48),
489
510
  axes: Optional[str] = None,
490
511
  data_type: Optional[Literal["tiff", "custom"]] = None,
491
- transforms: Optional[List[TRANSFORMS_UNION]] = None,
492
512
  tta_transforms: bool = True,
493
- dataloader_params: Optional[Dict] = None,
513
+ dataloader_params: Optional[dict] = None,
494
514
  read_source_func: Optional[Callable] = None,
495
515
  extension_filter: str = "",
496
516
  checkpoint: Optional[Literal["best", "last"]] = None,
497
- ) -> Union[list, np.ndarray]: ...
517
+ ) -> Union[list[NDArray], NDArray]: ...
498
518
 
499
519
  @overload
500
520
  def predict( # numpydoc ignore=GL08
501
521
  self,
502
- source: np.ndarray,
522
+ source: NDArray,
503
523
  *,
504
524
  batch_size: int = 1,
505
- tile_size: Optional[Tuple[int, ...]] = None,
506
- tile_overlap: Tuple[int, ...] = (48, 48),
525
+ tile_size: Optional[tuple[int, ...]] = None,
526
+ tile_overlap: tuple[int, ...] = (48, 48),
507
527
  axes: Optional[str] = None,
508
528
  data_type: Optional[Literal["array"]] = None,
509
- transforms: Optional[List[TRANSFORMS_UNION]] = None,
510
529
  tta_transforms: bool = True,
511
- dataloader_params: Optional[Dict] = None,
530
+ dataloader_params: Optional[dict] = None,
512
531
  checkpoint: Optional[Literal["best", "last"]] = None,
513
- ) -> Union[list, np.ndarray]: ...
532
+ ) -> Union[list[NDArray], NDArray]: ...
514
533
 
515
534
  def predict(
516
535
  self,
517
- source: Union[CAREamicsPredictData, Path, str, np.ndarray],
536
+ source: Union[CAREamicsPredictData, Path, str, NDArray],
518
537
  *,
519
- batch_size: int = 1,
520
- tile_size: Optional[Tuple[int, ...]] = None,
521
- tile_overlap: Tuple[int, ...] = (48, 48),
538
+ batch_size: Optional[int] = None,
539
+ tile_size: Optional[tuple[int, ...]] = None,
540
+ tile_overlap: tuple[int, ...] = (48, 48),
522
541
  axes: Optional[str] = None,
523
542
  data_type: Optional[Literal["array", "tiff", "custom"]] = None,
524
- transforms: Optional[List[TRANSFORMS_UNION]] = None,
525
543
  tta_transforms: bool = True,
526
- dataloader_params: Optional[Dict] = None,
544
+ dataloader_params: Optional[dict] = None,
527
545
  read_source_func: Optional[Callable] = None,
528
546
  extension_filter: str = "",
529
547
  checkpoint: Optional[Literal["best", "last"]] = None,
530
548
  **kwargs: Any,
531
- ) -> Union[List[np.ndarray], np.ndarray]:
549
+ ) -> Union[list[NDArray], NDArray]:
532
550
  """
533
551
  Make predictions on the provided data.
534
552
 
535
- Input can be a CAREamicsClay instance, a path to a data file, or a numpy array.
553
+ Input can be a CAREamicsPredData instance, a path to a data file, or a numpy
554
+ array.
536
555
 
537
556
  If `data_type`, `axes` and `tile_size` are not provided, the training
538
557
  configuration parameters will be used, with the `patch_size` instead of
539
558
  `tile_size`.
540
559
 
541
- The default transforms are defined in the `InferenceModel` Pydantic model.
542
-
543
560
  Test-time augmentation (TTA) can be switched off using the `tta_transforms`
544
- parameter.
561
+ parameter. The TTA augmentation applies all possible flip and 90 degrees
562
+ rotations to the prediction input and averages the predictions. TTA augmentation
563
+ should not be used if you did not train with these augmentations.
545
564
 
546
565
  Note that if you are using a UNet model and tiling, the tile size must be
547
566
  divisible in every dimension by 2**d, where d is the depth of the model. This
@@ -551,181 +570,96 @@ class CAREamist:
551
570
 
552
571
  Parameters
553
572
  ----------
554
- source : Union[CAREamicsClay, Path, str, np.ndarray]
573
+ source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
555
574
  Data to predict on.
556
- batch_size : int, optional
557
- Batch size for prediction, by default 1.
558
- tile_size : Optional[Tuple[int, ...]], optional
559
- Size of the tiles to use for prediction, by default None.
560
- tile_overlap : Tuple[int, ...], optional
561
- Overlap between tiles, by default (48, 48).
562
- axes : Optional[str], optional
575
+ batch_size : int, default=1
576
+ Batch size for prediction.
577
+ tile_size : tuple of int, optional
578
+ Size of the tiles to use for prediction.
579
+ tile_overlap : tuple of int, default=(48, 48)
580
+ Overlap between tiles.
581
+ axes : str, optional
563
582
  Axes of the input data, by default None.
564
- data_type : Optional[Literal["array", "tiff", "custom"]], optional
565
- Type of the input data, by default None.
566
- transforms : Optional[List[TRANSFORMS_UNION]], optional
567
- List of transforms to apply to the data, by default None.
568
- tta_transforms : bool, optional
569
- Whether to apply test-time augmentation, by default True.
570
- dataloader_params : Optional[Dict], optional
571
- Parameters to pass to the dataloader, by default None.
572
- read_source_func : Optional[Callable], optional
573
- Function to read the source data, by default None.
574
- extension_filter : str, optional
575
- Filter for the file extension, by default "".
576
- checkpoint : Optional[Literal["best", "last"]], optional
577
- Checkpoint to use for prediction, by default None.
583
+ data_type : {"array", "tiff", "custom"}, optional
584
+ Type of the input data.
585
+ tta_transforms : bool, default=True
586
+ Whether to apply test-time augmentation.
587
+ dataloader_params : dict, optional
588
+ Parameters to pass to the dataloader.
589
+ read_source_func : Callable, optional
590
+ Function to read the source data.
591
+ extension_filter : str, default=""
592
+ Filter for the file extension.
593
+ checkpoint : {"best", "last"}, optional
594
+ Checkpoint to use for prediction.
578
595
  **kwargs : Any
579
596
  Unused.
580
597
 
581
598
  Returns
582
599
  -------
583
- Union[List[np.ndarray], np.ndarray]
600
+ list of NDArray or NDArray
584
601
  Predictions made by the model.
585
-
586
- Raises
587
- ------
588
- ValueError
589
- If the input is not a CAREamicsClay instance, a path or a numpy array.
590
602
  """
591
- if isinstance(source, CAREamicsPredictData):
592
- # record datamodule
593
- self.pred_datamodule = source
594
-
595
- return self.trainer.predict(
596
- model=self.model, datamodule=source, ckpt_path=checkpoint
597
- )
598
- else:
599
- if self.cfg is None:
600
- raise ValueError(
601
- "No configuration found. Train a model or load from a "
602
- "checkpoint before predicting."
603
- )
604
- # create predict config, reuse training config if parameters missing
605
- prediction_config = create_inference_configuration(
606
- configuration=self.cfg,
607
- tile_size=tile_size,
608
- tile_overlap=tile_overlap,
609
- data_type=data_type,
610
- axes=axes,
611
- transforms=transforms,
612
- tta_transforms=tta_transforms,
613
- batch_size=batch_size,
603
+ # Reuse batch size if not provided explicitly
604
+ if batch_size is None:
605
+ batch_size = (
606
+ self.train_datamodule.batch_size
607
+ if self.train_datamodule
608
+ else self.cfg.data_config.batch_size
614
609
  )
615
610
 
616
- # remove batch from dataloader parameters (priority given to config)
617
- if dataloader_params is None:
618
- dataloader_params = {}
619
- if "batch_size" in dataloader_params:
620
- del dataloader_params["batch_size"]
621
-
622
- if isinstance(source, Path) or isinstance(source, str):
623
- # Check the source
624
- source_path = check_path_exists(source)
625
-
626
- # create datamodule
627
- datamodule = CAREamicsPredictData(
628
- pred_config=prediction_config,
629
- pred_data=source_path,
630
- read_source_func=read_source_func,
631
- extension_filter=extension_filter,
632
- dataloader_params=dataloader_params,
633
- )
634
-
635
- # record datamodule
636
- self.pred_datamodule = datamodule
637
-
638
- return self.trainer.predict(
639
- model=self.model, datamodule=datamodule, ckpt_path=checkpoint
640
- )
641
-
642
- elif isinstance(source, np.ndarray):
643
- # create datamodule
644
- datamodule = CAREamicsPredictData(
645
- pred_config=prediction_config,
646
- pred_data=source,
647
- dataloader_params=dataloader_params,
648
- )
649
-
650
- # record datamodule
651
- self.pred_datamodule = datamodule
652
-
653
- return self.trainer.predict(
654
- model=self.model, datamodule=datamodule, ckpt_path=checkpoint
655
- )
611
+ self.pred_datamodule = create_pred_datamodule(
612
+ source=source,
613
+ config=self.cfg,
614
+ batch_size=batch_size,
615
+ tile_size=tile_size,
616
+ tile_overlap=tile_overlap,
617
+ axes=axes,
618
+ data_type=data_type,
619
+ tta_transforms=tta_transforms,
620
+ dataloader_params=dataloader_params,
621
+ read_source_func=read_source_func,
622
+ extension_filter=extension_filter,
623
+ )
656
624
 
657
- else:
658
- raise ValueError(
659
- f"Invalid input. Expected a CAREamicsWood instance, paths or "
660
- f"np.ndarray (got {type(source)})."
661
- )
625
+ predictions = self.trainer.predict(
626
+ model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint
627
+ )
628
+ return convert_outputs(predictions, self.pred_datamodule.tiled)
662
629
 
663
630
  def export_to_bmz(
664
631
  self,
665
632
  path: Union[Path, str],
666
633
  name: str,
667
- authors: List[dict],
668
- input_array: Optional[np.ndarray] = None,
634
+ input_array: NDArray,
635
+ authors: list[dict],
669
636
  general_description: str = "",
670
- channel_names: Optional[List[str]] = None,
637
+ channel_names: Optional[list[str]] = None,
671
638
  data_description: Optional[str] = None,
672
639
  ) -> None:
673
640
  """Export the model to the BioImage Model Zoo format.
674
641
 
675
- Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
642
+ Input array must be of the same dimensions as the axes recorded in the
643
+ configuration of the `CAREamist`.
676
644
 
677
645
  Parameters
678
646
  ----------
679
- path : Union[Path, str]
647
+ path : pathlib.Path or str
680
648
  Path to save the model.
681
649
  name : str
682
650
  Name of the model.
683
- authors : List[dict]
651
+ input_array : NDArray
652
+ Input array used to validate the model and as example.
653
+ authors : list of dict
684
654
  List of authors of the model.
685
- input_array : Optional[np.ndarray], optional
686
- Input array for the model, must be of shape SC(Z)YX, by default None.
687
655
  general_description : str
688
656
  General description of the model, used in the metadata of the BMZ archive.
689
- channel_names : Optional[List[str]], optional
657
+ channel_names : list of str, optional
690
658
  Channel names, by default None.
691
- data_description : Optional[str], optional
659
+ data_description : str, optional
692
660
  Description of the data, by default None.
693
661
  """
694
- if input_array is None:
695
- # generate images, priority is given to the prediction data module
696
- if self.pred_datamodule is not None:
697
- # unpack a batch, ignore masks or targets
698
- input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))
699
-
700
- # convert torch.Tensor to numpy
701
- input_patch = input_patch.numpy()
702
- elif self.train_datamodule is not None:
703
- input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
704
- input_patch = input_patch.numpy()
705
- else:
706
- if (
707
- self.cfg.data_config.mean is None
708
- or self.cfg.data_config.std is None
709
- ):
710
- raise ValueError(
711
- "Mean and std cannot be None in the configuration in order to"
712
- "export to the BMZ format. Was the model trained?"
713
- )
714
-
715
- # create a random input array
716
- input_patch = np.random.normal(
717
- loc=self.cfg.data_config.mean,
718
- scale=self.cfg.data_config.std,
719
- size=self.cfg.data_config.patch_size,
720
- ).astype(np.float32)[
721
- np.newaxis, np.newaxis, ...
722
- ] # add S & C dimensions
723
- else:
724
- input_patch = input_array
725
-
726
- # if there is a batch dimension
727
- if input_patch.shape[0] > 1:
728
- input_patch = input_patch[0:1, ...] # keep singleton dim
662
+ input_patch = reshape_array(input_array, self.cfg.data_config.axes)
729
663
 
730
664
  # axes need to be reformated for the export because reshaping was done in the
731
665
  # datamodule
@@ -742,11 +676,10 @@ class CAREamist:
742
676
  tta_transforms=False,
743
677
  )
744
678
 
745
- if not isinstance(output_patch, np.ndarray):
746
- raise ValueError(
747
- f"Numpy array required for export to BioImage Model Zoo, got "
748
- f"{type(output_patch)}."
749
- )
679
+ if isinstance(output_patch, list):
680
+ output = np.concatenate(output_patch, axis=0)
681
+ else:
682
+ output = output_patch
750
683
 
751
684
  export_to_bmz(
752
685
  model=self.model,
@@ -756,7 +689,7 @@ class CAREamist:
756
689
  general_description=general_description,
757
690
  authors=authors,
758
691
  input_array=input_patch,
759
- output_array=output_patch,
692
+ output_array=output,
760
693
  channel_names=channel_names,
761
694
  data_description=data_description,
762
695
  )