careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.

Potentially problematic release.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
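
Most of the file moves above repackage the Lightning glue code: the top-level lightning_module, lightning_datamodule and lightning_prediction_datamodule modules (and the careamics.callbacks package) are consolidated under careamics.lightning, while stitching moves from careamics.prediction to careamics.prediction_utils. For downstream code, the import migration suggested by the careamist.py diff below looks roughly like this (a minimal sketch; the re-exported names are taken from the new import block in careamist.py, not from an exhaustive audit of the package):

# 0.1.0rc6: separate top-level modules
# from careamics.lightning_module import CAREamicsModule
# from careamics.lightning_datamodule import CAREamicsTrainData
# from careamics.lightning_prediction_datamodule import CAREamicsPredictData

# 0.1.0rc8: one careamics.lightning package, with renamed data modules
from careamics.lightning import (
    CAREamicsModule,  # unchanged name
    TrainDataModule,  # was CAREamicsTrainData
    PredictDataModule,  # was CAREamicsPredictData
    create_predict_datamodule,  # new factory used by CAREamist.predict
)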
careamics/careamist.py CHANGED
@@ -1,9 +1,10 @@
 """A class to train, predict and export models in CAREamics."""
 
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload
+from typing import Any, Callable, Literal, Optional, Union, overload
 
 import numpy as np
+from numpy.typing import NDArray
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import (
     Callback,
@@ -12,59 +13,64 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 
-from careamics.callbacks import ProgressBarCallback
 from careamics.config import (
     Configuration,
-    create_inference_configuration,
     load_configuration,
 )
-from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
+from careamics.config.support import (
+    SupportedAlgorithm,
+    SupportedArchitecture,
+    SupportedData,
+    SupportedLogger,
+)
 from careamics.dataset.dataset_utils import reshape_array
-from careamics.lightning_datamodule import CAREamicsTrainData
-from careamics.lightning_module import CAREamicsModule
-from careamics.lightning_prediction_datamodule import CAREamicsPredictData
-from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
+from careamics.lightning import (
+    CAREamicsModule,
+    HyperParametersCallback,
+    PredictDataModule,
+    ProgressBarCallback,
+    TrainDataModule,
+    create_predict_datamodule,
+)
 from careamics.model_io import export_to_bmz, load_pretrained
-from careamics.transforms import Denormalize
+from careamics.prediction_utils import convert_outputs
 from careamics.utils import check_path_exists, get_logger
 
-from .callbacks import HyperParametersCallback
-
 logger = get_logger(__name__)
 
 LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
 
 
-# TODO napari callbacks
-# TODO: how to do AMP? How to continue training?
 class CAREamist:
     """Main CAREamics class, allowing training and prediction using various algorithms.
 
     Parameters
     ----------
-    source : Union[Path, str, Configuration]
+    source : pathlib.Path or str or CAREamics Configuration
         Path to a configuration file or a trained model.
-    work_dir : Optional[str], optional
+    work_dir : str, optional
         Path to working directory in which to save checkpoints and logs,
         by default None.
-    experiment_name : str, optional
-        Experiment name used for checkpoints, by default "CAREamics".
+    experiment_name : str, by default "CAREamics"
+        Experiment name used for checkpoints.
+    callbacks : list of Callback, optional
+        List of callbacks to use during training and prediction, by default None.
 
     Attributes
     ----------
-    model : CAREamicsKiln
+    model : CAREamicsModule
        CAREamics model.
     cfg : Configuration
        CAREamics configuration.
     trainer : Trainer
        PyTorch Lightning trainer.
-    experiment_logger : Optional[Union[TensorBoardLogger, WandbLogger]]
+    experiment_logger : TensorBoardLogger or WandbLogger
        Experiment logger, "wandb" or "tensorboard".
-    work_dir : Path
+    work_dir : pathlib.Path
        Working directory.
-    train_datamodule : Optional[CAREamicsWood]
+    train_datamodule : TrainDataModule
        Training datamodule.
-    pred_datamodule : Optional[CAREamicsClay]
+    pred_datamodule : PredictDataModule
        Prediction datamodule.
     """
 
@@ -74,6 +80,7 @@ class CAREamist:
         source: Union[Path, str],
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
     ) -> None: ...
 
     @overload
@@ -82,6 +89,7 @@ class CAREamist:
         source: Configuration,
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
     ) -> None: ...
 
     def __init__(
@@ -89,6 +97,7 @@ class CAREamist:
         source: Union[Path, str, Configuration],
         work_dir: Optional[Union[Path, str]] = None,
         experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
     ) -> None:
         """
         Initialize CAREamist with a configuration object or a path.
@@ -105,13 +114,15 @@ class CAREamist:
 
         Parameters
         ----------
-        source : Union[Path, str, Configuration]
+        source : pathlib.Path or str or CAREamics Configuration
             Path to a configuration file or a trained model.
-        work_dir : Optional[str], optional
+        work_dir : str, optional
             Path to working directory in which to save checkpoints and logs,
             by default None.
         experiment_name : str, optional
             Experiment name used for checkpoints, by default "CAREamics".
+        callbacks : list of Callback, optional
+            List of callbacks to use during training and prediction, by default None.
 
         Raises
         ------
@@ -164,7 +175,7 @@ class CAREamist:
             self.model, self.cfg = load_pretrained(source)
 
         # define the checkpoint saving callback
-        self.callbacks = self._define_callbacks()
+        self._define_callbacks(callbacks)
 
         # instantiate logger
         if self.cfg.training_config.has_logger():
@@ -188,32 +199,50 @@ class CAREamist:
            logger=self.experiment_logger,
         )
 
-        # change the prediction loop, necessary for tiled prediction
-        self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
-
         # place holder for the datamodules
-        self.train_datamodule: Optional[CAREamicsTrainData] = None
-        self.pred_datamodule: Optional[CAREamicsPredictData] = None
+        self.train_datamodule: Optional[TrainDataModule] = None
+        self.pred_datamodule: Optional[PredictDataModule] = None
 
-    def _define_callbacks(self) -> List[Callback]:
+    def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
         """
         Define the callbacks for the training loop.
 
-        Returns
-        -------
-        List[Callback]
-            List of callbacks to be used during training.
+        Parameters
+        ----------
+        callbacks : list of Callback, optional
+            List of callbacks to use during training and prediction, by default None.
         """
+        self.callbacks = [] if callbacks is None else callbacks
+
+        # check that user callbacks are not any of the CAREamics callbacks
+        for c in self.callbacks:
+            if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
+                raise ValueError(
+                    "ModelCheckpoint and EarlyStopping callbacks are already defined "
+                    "in CAREamics and should only be modified through the "
+                    "training configuration (see TrainingConfig)."
+                )
+
+            if isinstance(c, HyperParametersCallback) or isinstance(
+                c, ProgressBarCallback
+            ):
+                raise ValueError(
+                    "HyperParameter and ProgressBar callbacks are defined internally "
+                    "and should not be passed as callbacks."
+                )
+
         # checkpoint callback saves checkpoints during training
-        self.callbacks = [
-            HyperParametersCallback(self.cfg),
-            ModelCheckpoint(
-                dirpath=self.work_dir / Path("checkpoints"),
-                filename=self.cfg.experiment_name,
-                **self.cfg.training_config.checkpoint_callback.model_dump(),
-            ),
-            ProgressBarCallback(),
-        ]
+        self.callbacks.extend(
+            [
+                HyperParametersCallback(self.cfg),
+                ModelCheckpoint(
+                    dirpath=self.work_dir / Path("checkpoints"),
+                    filename=self.cfg.experiment_name,
+                    **self.cfg.training_config.checkpoint_callback.model_dump(),
+                ),
+                ProgressBarCallback(),
+            ]
+        )
 
         # early stopping callback
         if self.cfg.training_config.early_stopping_callback is not None:
@@ -221,16 +250,14 @@ class CAREamist:
                EarlyStopping(self.cfg.training_config.early_stopping_callback)
            )
 
-        return self.callbacks
-
     def train(
         self,
         *,
-        datamodule: Optional[CAREamicsTrainData] = None,
-        train_source: Optional[Union[Path, str, np.ndarray]] = None,
-        val_source: Optional[Union[Path, str, np.ndarray]] = None,
-        train_target: Optional[Union[Path, str, np.ndarray]] = None,
-        val_target: Optional[Union[Path, str, np.ndarray]] = None,
+        datamodule: Optional[TrainDataModule] = None,
+        train_source: Optional[Union[Path, str, NDArray]] = None,
+        val_source: Optional[Union[Path, str, NDArray]] = None,
+        train_target: Optional[Union[Path, str, NDArray]] = None,
+        val_target: Optional[Union[Path, str, NDArray]] = None,
         use_in_memory: bool = True,
         val_percentage: float = 0.1,
         val_minimum_split: int = 1,
@@ -253,15 +280,15 @@ class CAREamist:
 
         Parameters
         ----------
-        datamodule : Optional[CAREamicsWood], optional
+        datamodule : TrainDataModule, optional
            Datamodule to train on, by default None.
-        train_source : Optional[Union[Path, str, np.ndarray]], optional
+        train_source : pathlib.Path or str or NDArray, optional
            Train source, if no datamodule is provided, by default None.
-        val_source : Optional[Union[Path, str, np.ndarray]], optional
+        val_source : pathlib.Path or str or NDArray, optional
            Validation source, if no datamodule is provided, by default None.
-        train_target : Optional[Union[Path, str, np.ndarray]], optional
+        train_target : pathlib.Path or str or NDArray, optional
            Train target source, if no datamodule is provided, by default None.
-        val_target : Optional[Union[Path, str, np.ndarray]], optional
+        val_target : pathlib.Path or str or NDArray, optional
            Validation target source, if no datamodule is provided, by default None.
         use_in_memory : bool, optional
            Use in memory dataset if possible, by default True.
@@ -355,17 +382,17 @@ class CAREamist:
 
         else:
             raise ValueError(
-                f"Invalid input, expected a str, Path, array or CAREamicsWood "
+                f"Invalid input, expected a str, Path, array or TrainDataModule "
                f"instance (got {type(train_source)})."
            )
 
-    def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
+    def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
         """
         Train the model on the provided datamodule.
 
         Parameters
         ----------
-        datamodule : CAREamicsWood
+        datamodule : TrainDataModule
            Datamodule to train on.
         """
         # record datamodule
@@ -375,10 +402,10 @@ class CAREamist:
 
     def _train_on_array(
         self,
-        train_data: np.ndarray,
-        val_data: Optional[np.ndarray] = None,
-        train_target: Optional[np.ndarray] = None,
-        val_target: Optional[np.ndarray] = None,
+        train_data: NDArray,
+        val_data: Optional[NDArray] = None,
+        train_target: Optional[NDArray] = None,
+        val_target: Optional[NDArray] = None,
         val_percentage: float = 0.1,
         val_minimum_split: int = 5,
     ) -> None:
@@ -387,13 +414,13 @@ class CAREamist:
 
         Parameters
         ----------
-        train_data : np.ndarray
+        train_data : NDArray
            Training data.
-        val_data : Optional[np.ndarray], optional
+        val_data : NDArray, optional
            Validation data, by default None.
-        train_target : Optional[np.ndarray], optional
+        train_target : NDArray, optional
            Train target data, by default None.
-        val_target : Optional[np.ndarray], optional
+        val_target : NDArray, optional
            Validation target data, by default None.
         val_percentage : float, optional
            Percentage of patches to use for validation, by default 0.1.
@@ -401,7 +428,7 @@ class CAREamist:
            Minimum number of patches to use for validation, by default 5.
         """
         # create datamodule
-        datamodule = CAREamicsTrainData(
+        datamodule = TrainDataModule(
            data_config=self.cfg.data_config,
            train_data=train_data,
            val_data=val_data,
@@ -429,13 +456,13 @@ class CAREamist:
 
         Parameters
         ----------
-        path_to_train_data : Union[Path, str]
+        path_to_train_data : pathlib.Path or str
            Path to the training data.
-        path_to_val_data : Optional[Union[Path, str]], optional
+        path_to_val_data : pathlib.Path or str, optional
            Path to validation data, by default None.
-        path_to_train_target : Optional[Union[Path, str]], optional
+        path_to_train_target : pathlib.Path or str, optional
            Path to train target data, by default None.
-        path_to_val_target : Optional[Union[Path, str]], optional
+        path_to_val_target : pathlib.Path or str, optional
            Path to validation target data, by default None.
         use_in_memory : bool, optional
            Use in memory dataset if possible, by default True.
@@ -457,7 +484,7 @@ class CAREamist:
            path_to_val_target = check_path_exists(path_to_val_target)
 
         # create datamodule
-        datamodule = CAREamicsTrainData(
+        datamodule = TrainDataModule(
            data_config=self.cfg.data_config,
            train_data=path_to_train_data,
            val_data=path_to_val_data,
@@ -473,11 +500,8 @@ class CAREamist:
 
     @overload
     def predict(  # numpydoc ignore=GL08
-        self,
-        source: CAREamicsPredictData,
-        *,
-        checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]: ...
+        self, source: PredictDataModule
+    ) -> Union[list[NDArray], NDArray]: ...
 
     @overload
     def predict(  # numpydoc ignore=GL08
@@ -485,59 +509,59 @@ class CAREamist:
         source: Union[Path, str],
         *,
         batch_size: int = 1,
-        tile_size: Optional[Tuple[int, ...]] = None,
-        tile_overlap: Tuple[int, ...] = (48, 48),
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["tiff", "custom"]] = None,
         tta_transforms: bool = True,
-        dataloader_params: Optional[Dict] = None,
+        dataloader_params: Optional[dict] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
-        checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]: ...
+    ) -> Union[list[NDArray], NDArray]: ...
 
     @overload
     def predict(  # numpydoc ignore=GL08
         self,
-        source: np.ndarray,
+        source: NDArray,
         *,
         batch_size: int = 1,
-        tile_size: Optional[Tuple[int, ...]] = None,
-        tile_overlap: Tuple[int, ...] = (48, 48),
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array"]] = None,
         tta_transforms: bool = True,
-        dataloader_params: Optional[Dict] = None,
-        checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list, np.ndarray]: ...
+        dataloader_params: Optional[dict] = None,
+    ) -> Union[list[NDArray], NDArray]: ...
 
     def predict(
         self,
-        source: Union[CAREamicsPredictData, Path, str, np.ndarray],
+        source: Union[PredictDataModule, Path, str, NDArray],
         *,
-        batch_size: int = 1,
-        tile_size: Optional[Tuple[int, ...]] = None,
-        tile_overlap: Tuple[int, ...] = (48, 48),
+        batch_size: Optional[int] = None,
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array", "tiff", "custom"]] = None,
         tta_transforms: bool = True,
-        dataloader_params: Optional[Dict] = None,
+        dataloader_params: Optional[dict] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
-        checkpoint: Optional[Literal["best", "last"]] = None,
         **kwargs: Any,
-    ) -> Union[List[np.ndarray], np.ndarray]:
+    ) -> Union[list[NDArray], NDArray]:
         """
         Make predictions on the provided data.
 
-        Input can be a CAREamicsClay instance, a path to a data file, or a numpy array.
+        Input can be a CAREamicsPredData instance, a path to a data file, or a numpy
+        array.
 
         If `data_type`, `axes` and `tile_size` are not provided, the training
         configuration parameters will be used, with the `patch_size` instead of
         `tile_size`.
 
         Test-time augmentation (TTA) can be switched off using the `tta_transforms`
-        parameter.
+        parameter. The TTA augmentation applies all possible flip and 90 degrees
+        rotations to the prediction input and averages the predictions. TTA augmentation
+        should not be used if you did not train with these augmentations.
 
         Note that if you are using a UNet model and tiling, the tile size must be
         divisible in every dimension by 2**d, where d is the depth of the model. This
@@ -547,242 +571,136 @@ class CAREamist:
 
         Parameters
         ----------
-        source : Union[CAREamicsClay, Path, str, np.ndarray]
+        source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
            Data to predict on.
-        batch_size : int, optional
-            Batch size for prediction, by default 1.
-        tile_size : Optional[Tuple[int, ...]], optional
-            Size of the tiles to use for prediction, by default None.
-        tile_overlap : Tuple[int, ...], optional
-            Overlap between tiles, by default (48, 48).
-        axes : Optional[str], optional
+        batch_size : int, default=1
+            Batch size for prediction.
+        tile_size : tuple of int, optional
+            Size of the tiles to use for prediction.
+        tile_overlap : tuple of int, default=(48, 48)
+            Overlap between tiles.
+        axes : str, optional
            Axes of the input data, by default None.
-        data_type : Optional[Literal["array", "tiff", "custom"]], optional
-            Type of the input data, by default None.
-        tta_transforms : bool, optional
-            Whether to apply test-time augmentation, by default True.
-        dataloader_params : Optional[Dict], optional
-            Parameters to pass to the dataloader, by default None.
-        read_source_func : Optional[Callable], optional
-            Function to read the source data, by default None.
-        extension_filter : str, optional
-            Filter for the file extension, by default "".
-        checkpoint : Optional[Literal["best", "last"]], optional
-            Checkpoint to use for prediction, by default None.
+        data_type : {"array", "tiff", "custom"}, optional
+            Type of the input data.
+        tta_transforms : bool, default=True
+            Whether to apply test-time augmentation.
+        dataloader_params : dict, optional
+            Parameters to pass to the dataloader.
+        read_source_func : Callable, optional
+            Function to read the source data.
+        extension_filter : str, default=""
+            Filter for the file extension.
         **kwargs : Any
            Unused.
 
         Returns
         -------
-        Union[List[np.ndarray], np.ndarray]
+        list of NDArray or NDArray
            Predictions made by the model.
 
-        Raises
-        ------
-        ValueError
-            If the input is not a CAREamicsClay instance, a path or a numpy array.
-        """
-        if isinstance(source, CAREamicsPredictData):
-            # record datamodule
-            self.pred_datamodule = source
-
-            return self.trainer.predict(
-                model=self.model, datamodule=source, ckpt_path=checkpoint
-            )
-        else:
-            if self.cfg is None:
-                raise ValueError(
-                    "No configuration found. Train a model or load from a "
-                    "checkpoint before predicting."
-                )
-            # create predict config, reuse training config if parameters missing
-            prediction_config = create_inference_configuration(
-                configuration=self.cfg,
-                tile_size=tile_size,
-                tile_overlap=tile_overlap,
-                data_type=data_type,
-                axes=axes,
-                tta_transforms=tta_transforms,
-                batch_size=batch_size,
-            )
-
-            # remove batch from dataloader parameters (priority given to config)
-            if dataloader_params is None:
-                dataloader_params = {}
-            if "batch_size" in dataloader_params:
-                del dataloader_params["batch_size"]
-
-            if isinstance(source, Path) or isinstance(source, str):
-                # Check the source
-                source_path = check_path_exists(source)
-
-                # create datamodule
-                datamodule = CAREamicsPredictData(
-                    pred_config=prediction_config,
-                    pred_data=source_path,
-                    read_source_func=read_source_func,
-                    extension_filter=extension_filter,
-                    dataloader_params=dataloader_params,
-                )
-
-                # record datamodule
-                self.pred_datamodule = datamodule
-
-                return self.trainer.predict(
-                    model=self.model, datamodule=datamodule, ckpt_path=checkpoint
-                )
-
-            elif isinstance(source, np.ndarray):
-                # create datamodule
-                datamodule = CAREamicsPredictData(
-                    pred_config=prediction_config,
-                    pred_data=source,
-                    dataloader_params=dataloader_params,
-                )
-
-                # record datamodule
-                self.pred_datamodule = datamodule
-
-                return self.trainer.predict(
-                    model=self.model, datamodule=datamodule, ckpt_path=checkpoint
-                )
-
-            else:
-                raise ValueError(
-                    f"Invalid input. Expected a CAREamicsWood instance, paths or "
-                    f"np.ndarray (got {type(source)})."
-                )
-
-    def _create_data_for_bmz(
-        self,
-        input_array: Optional[np.ndarray] = None,
-    ) -> np.ndarray:
-        """Create data for BMZ export.
-
-        If no `input_array` is provided, this method checks if there is a prediction
-        datamodule, or a training data module, to extract a patch. If none exists,
-        then a random aray is created.
-
-        If there is a non-singleton batch dimension, this method returns only the first
-        element.
-
-        Parameters
-        ----------
-        input_array : Optional[np.ndarray], optional
-            Input array, by default None.
-
-        Returns
-        -------
-        np.ndarray
-            Input data for BMZ export.
-
         Raises
         ------
         ValueError
            If mean and std are not provided in the configuration.
+        ValueError
+            If tile size is not divisible by 2**depth for UNet models.
+        ValueError
+            If tile overlap is not specified.
         """
-        if input_array is None:
-            if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
-                raise ValueError(
-                    "Mean and std cannot be None in the configuration in order to"
-                    "export to the BMZ format. Was the model trained?"
-                )
-
-            # generate images, priority is given to the prediction data module
-            if self.pred_datamodule is not None:
-                # unpack a batch, ignore masks or targets
-                input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))
-
-                # convert torch.Tensor to numpy
-                input_patch = input_patch.numpy()
-
-                # denormalize
-                denormalize = Denormalize(
-                    mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
-                )
-                input_patch, _ = denormalize(input_patch)
-
-            elif self.train_datamodule is not None:
-                input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
-                input_patch = input_patch.numpy()
-
-                # denormalize
-                denormalize = Denormalize(
-                    mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
-                )
-                input_patch, _ = denormalize(input_patch)
-            else:
-                # create a random input array
-                input_patch = np.random.normal(
-                    loc=self.cfg.data_config.mean,
-                    scale=self.cfg.data_config.std,
-                    size=self.cfg.data_config.patch_size,
-                ).astype(np.float32)[
-                    np.newaxis, np.newaxis, ...
-                ]  # add S & C dimensions
-        else:
-            # potentially correct shape
-            input_patch = reshape_array(input_array, self.cfg.data_config.axes)
-
-        # if this a batch
-        if input_patch.shape[0] > 1:
-            input_patch = input_patch[[0], ...]  # keep singleton dim
+        if (
+            self.cfg.data_config.image_means is None
+            or self.cfg.data_config.image_stds is None
+        ):
+            raise ValueError("Mean and std must be provided in the configuration.")
+
+        # tile size for UNets
+        if tile_size is not None:
+            model = self.cfg.algorithm_config.model
+
+            if model.architecture == SupportedArchitecture.UNET.value:
+                # tile size must be equal to k*2^n, where n is the number of pooling
+                # layers (equal to the depth) and k is an integer
+                depth = model.depth
+                tile_increment = 2**depth
+
+                for i, t in enumerate(tile_size):
+                    if t % tile_increment != 0:
+                        raise ValueError(
+                            f"Tile size must be divisible by {tile_increment} along "
+                            f"all axes (got {t} for axis {i}). If your image size is "
+                            f"smaller along one axis (e.g. Z), consider padding the "
+                            f"image."
+                        )
+
+            # tile overlaps must be specified
+            if tile_overlap is None:
+                raise ValueError("Tile overlap must be specified.")
+
+        # create the prediction
+        self.pred_datamodule = create_predict_datamodule(
+            pred_data=source,
+            data_type=data_type or self.cfg.data_config.data_type,
+            axes=axes or self.cfg.data_config.axes,
+            image_means=self.cfg.data_config.image_means,
+            image_stds=self.cfg.data_config.image_stds,
+            tile_size=tile_size,
+            tile_overlap=tile_overlap,
+            batch_size=batch_size or self.cfg.data_config.batch_size,
+            tta_transforms=tta_transforms,
+            read_source_func=read_source_func,
+            extension_filter=extension_filter,
+            dataloader_params=dataloader_params,
+        )
 
-        return input_patch
+        # predict
+        predictions = self.trainer.predict(
+            model=self.model, datamodule=self.pred_datamodule
+        )
+        return convert_outputs(predictions, self.pred_datamodule.tiled)
 
     def export_to_bmz(
         self,
         path: Union[Path, str],
         name: str,
-        authors: List[dict],
-        input_array: Optional[np.ndarray] = None,
+        input_array: NDArray,
+        authors: list[dict],
         general_description: str = "",
-        channel_names: Optional[List[str]] = None,
+        channel_names: Optional[list[str]] = None,
         data_description: Optional[str] = None,
     ) -> None:
         """Export the model to the BioImage Model Zoo format.
 
-        Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
+        Input array must be of the same dimensions as the axes recorded in the
+        configuration of the `CAREamist`.
 
         Parameters
         ----------
-        path : Union[Path, str]
+        path : pathlib.Path or str
            Path to save the model.
         name : str
            Name of the model.
-        authors : List[dict]
+        input_array : NDArray
+            Input array used to validate the model and as example.
+        authors : list of dict
            List of authors of the model.
-        input_array : Optional[np.ndarray], optional
-            Input array for the model, must be of shape SC(Z)YX, by default None.
         general_description : str
            General description of the model, used in the metadata of the BMZ archive.
-        channel_names : Optional[List[str]], optional
+        channel_names : list of str, optional
            Channel names, by default None.
-        data_description : Optional[str], optional
+        data_description : str, optional
            Description of the data, by default None.
         """
-        input_patch = self._create_data_for_bmz(input_array)
+        # TODO: add in docs that it is expected that input_array dimensions match
+        # those in data_config
 
-        # axes need to be reformated for the export because reshaping was done in the
-        # datamodule
-        if "Z" in self.cfg.data_config.axes:
-            axes = "SCZYX"
-        else:
-            axes = "SCYX"
-
-        # predict output, remove extra dimensions for the purpose of the prediction
         output_patch = self.predict(
-            input_patch,
+            input_array,
            data_type=SupportedData.ARRAY.value,
-            axes=axes,
            tta_transforms=False,
         )
-
-        if not isinstance(output_patch, np.ndarray):
-            raise ValueError(
-                f"Numpy array required for export to BioImage Model Zoo, got "
-                f"{type(output_patch)}."
-            )
+        output = np.concatenate(output_patch, axis=0)
+        input_array = reshape_array(input_array, self.cfg.data_config.axes)
 
         export_to_bmz(
            model=self.model,
@@ -791,8 +709,8 @@ class CAREamist:
            name=name,
            general_description=general_description,
            authors=authors,
-            input_array=input_patch,
-            output_array=output_patch,
+            input_array=input_array,
+            output_array=output,
            channel_names=channel_names,
            data_description=data_description,
         )
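
Taken together, the careamist.py changes drop the custom CAREamicsPredictionLoop and the checkpoint argument: predict() now always builds a datamodule via create_predict_datamodule, runs the standard Trainer.predict, and post-processes with convert_outputs; user callbacks can be injected at construction (except ModelCheckpoint and EarlyStopping, which stay config-driven); and export_to_bmz() now requires input_array. A rough end-to-end sketch under those assumptions (the config path, array shapes, author entry and depth-2 UNet are illustrative, not taken from the diff):

import numpy as np
from careamics import CAREamist

# illustrative 2D training image; axes assumed to match the configuration
image = np.random.rand(256, 256).astype(np.float32)

# custom Lightning callbacks (other than checkpoint/early-stopping) can now be
# passed at construction: CAREamist(source=..., callbacks=[MyCallback()])
careamist = CAREamist(source="n2v_config.yml")  # hypothetical config file
careamist.train(train_source=image)

# tiled prediction: for a depth-2 UNet every tile dimension must be divisible
# by 2**2 = 4, so (64, 64) is valid while e.g. (66, 66) would raise ValueError
prediction = careamist.predict(
    image,
    tile_size=(64, 64),
    tile_overlap=(48, 48),  # tile_overlap must not be None when tiling
)

# export to the BioImage Model Zoo; input_array is now a required argument
careamist.export_to_bmz(
    path="n2v_model.zip",
    name="n2v_model",
    input_array=image,
    authors=[{"name": "Jane Doe", "affiliation": "Example Lab"}],
)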