careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py ADDED
@@ -0,0 +1,761 @@
+ """A class to train, predict and export models in CAREamics."""
+
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload
+
+ import numpy as np
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import (
+     Callback,
+     EarlyStopping,
+     ModelCheckpoint,
+ )
+ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+
+ from careamics.callbacks import ProgressBarCallback
+ from careamics.config import (
+     Configuration,
+     create_inference_configuration,
+     load_configuration,
+ )
+ from careamics.config.inference_model import TRANSFORMS_UNION
+ from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
+ from careamics.lightning_datamodule import CAREamicsTrainData
+ from careamics.lightning_module import CAREamicsModule
+ from careamics.lightning_prediction_datamodule import CAREamicsPredictData
+ from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
+ from careamics.model_io import export_to_bmz, load_pretrained
+ from careamics.utils import check_path_exists, get_logger
+
+ from .callbacks import HyperParametersCallback
+
+ logger = get_logger(__name__)
+
+ LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
+
+
+ # TODO napari callbacks
+ # TODO: how to do AMP? How to continue training?
+ class CAREamist:
+     """Main CAREamics class, allowing training and prediction using various algorithms.
+
+     Parameters
+     ----------
+     source : Union[Path, str, Configuration]
+         Path to a configuration file or a trained model.
+     work_dir : Optional[Union[Path, str]], optional
+         Path to working directory in which to save checkpoints and logs,
+         by default None.
+     experiment_name : str, optional
+         Experiment name used for checkpoints, by default "CAREamics".
+
+     Attributes
+     ----------
+     model : CAREamicsModule
+         CAREamics model.
+     cfg : Configuration
+         CAREamics configuration.
+     trainer : Trainer
+         PyTorch Lightning trainer.
+     experiment_logger : Optional[Union[TensorBoardLogger, WandbLogger]]
+         Experiment logger, "wandb" or "tensorboard".
+     work_dir : Path
+         Working directory.
+     train_datamodule : Optional[CAREamicsTrainData]
+         Training datamodule.
+     pred_datamodule : Optional[CAREamicsPredictData]
+         Prediction datamodule.
+     """
+
+     @overload
+     def __init__(  # numpydoc ignore=GL08
+         self,
+         source: Union[Path, str],
+         work_dir: Optional[str] = None,
+         experiment_name: str = "CAREamics",
+     ) -> None:
+         ...
+
+     @overload
+     def __init__(  # numpydoc ignore=GL08
+         self,
+         source: Configuration,
+         work_dir: Optional[str] = None,
+         experiment_name: str = "CAREamics",
+     ) -> None:
+         ...
+
+     def __init__(
+         self,
+         source: Union[Path, str, Configuration],
+         work_dir: Optional[Union[Path, str]] = None,
+         experiment_name: str = "CAREamics",
+     ) -> None:
+         """
+         Initialize CAREamist with a configuration object or a path.
+
+         A configuration object can be created directly by calling `Configuration`,
+         using the configuration factory, or by loading it from a yaml file.
+
+         The path can point to either a yaml file with parameters or a saved checkpoint.
+
+         If no working directory is provided, the current working directory is used.
+
+         If `source` is a checkpoint, then `experiment_name` is used to name the
+         checkpoint, and is recorded in the configuration.
+
+         Parameters
+         ----------
+         source : Union[Path, str, Configuration]
+             Path to a configuration file or a trained model.
+         work_dir : Optional[Union[Path, str]], optional
+             Path to working directory in which to save checkpoints and logs,
+             by default None.
+         experiment_name : str, optional
+             Experiment name used for checkpoints, by default "CAREamics".
+
+         Raises
+         ------
+         NotImplementedError
+             If the model is loaded from the BioImage Model Zoo.
+         ValueError
+             If no hyperparameters are found in the checkpoint.
+         ValueError
+             If no data module hyperparameters are found in the checkpoint.
+         """
+         super().__init__()
+
+         # select current working directory if work_dir is None
+         if work_dir is None:
+             self.work_dir = Path.cwd()
+             logger.warning(
+                 f"No working directory provided. Using current working directory: "
+                 f"{self.work_dir}."
+             )
+         else:
+             self.work_dir = Path(work_dir)
+
+         # configuration object
+         if isinstance(source, Configuration):
+             self.cfg = source
+
+             # instantiate model
+             self.model = CAREamicsModule(
+                 algorithm_config=self.cfg.algorithm_config,
+             )
+
+         # path to configuration file or model
+         else:
+             source = check_path_exists(source)
+
+             # configuration file
+             if source.is_file() and (
+                 source.suffix == ".yaml" or source.suffix == ".yml"
+             ):
+                 # load configuration
+                 self.cfg = load_configuration(source)
+
+                 # instantiate model
+                 self.model = CAREamicsModule(
+                     algorithm_config=self.cfg.algorithm_config,
+                 )
+
+             # attempt loading a pre-trained model
+             else:
+                 self.model, self.cfg = load_pretrained(source)
+
+         # define the checkpoint saving callback
+         self.callbacks = self._define_callbacks()
+
+         # instantiate logger
+         if self.cfg.training_config.has_logger():
+             if self.cfg.training_config.logger == SupportedLogger.WANDB:
+                 self.experiment_logger: LOGGER_TYPES = WandbLogger(
+                     name=experiment_name,
+                     save_dir=self.work_dir / Path("logs"),
+                 )
+             elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
+                 self.experiment_logger = TensorBoardLogger(
+                     save_dir=self.work_dir / Path("logs"),
+                 )
+         else:
+             self.experiment_logger = None
+
+         # instantiate trainer
+         self.trainer = Trainer(
+             max_epochs=self.cfg.training_config.num_epochs,
+             callbacks=self.callbacks,
+             default_root_dir=self.work_dir,
+             logger=self.experiment_logger,
+         )
+
+         # change the prediction loop, necessary for tiled prediction
+         self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
+
+         # placeholder for the datamodules
+         self.train_datamodule: Optional[CAREamicsTrainData] = None
+         self.pred_datamodule: Optional[CAREamicsPredictData] = None
+
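For orientation, a minimal sketch of constructing `CAREamist` from the sources accepted above. The file names are hypothetical, and importing from the package root assumes `CAREamist` is re-exported in `careamics/__init__.py` (consistent with the changes listed for that file):

    from careamics import CAREamist  # assumed re-export from the package root
    from careamics.config import load_configuration

    # any of a Configuration object, a yaml file or a checkpoint path works
    cfg = load_configuration("n2v_config.yml")  # hypothetical file
    careamist = CAREamist(source=cfg, work_dir="runs")
    careamist_from_ckpt = CAREamist(source="runs/checkpoints/CAREamics.ckpt")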
+     def _define_callbacks(self) -> List[Callback]:
+         """
+         Define the callbacks for the training loop.
+
+         Returns
+         -------
+         List[Callback]
+             List of callbacks to be used during training.
+         """
+         # checkpoint callback saves checkpoints during training
+         self.callbacks = [
+             HyperParametersCallback(self.cfg),
+             ModelCheckpoint(
+                 dirpath=self.work_dir / Path("checkpoints"),
+                 filename=self.cfg.experiment_name,
+                 **self.cfg.training_config.checkpoint_callback.model_dump(),
+             ),
+             ProgressBarCallback(),
+         ]
+
+         # early stopping callback
+         if self.cfg.training_config.early_stopping_callback is not None:
+             self.callbacks.append(
+                 EarlyStopping(self.cfg.training_config.early_stopping_callback)
+             )
+
+         return self.callbacks
+
+     def train(
+         self,
+         *,
+         datamodule: Optional[CAREamicsTrainData] = None,
+         train_source: Optional[Union[Path, str, np.ndarray]] = None,
+         val_source: Optional[Union[Path, str, np.ndarray]] = None,
+         train_target: Optional[Union[Path, str, np.ndarray]] = None,
+         val_target: Optional[Union[Path, str, np.ndarray]] = None,
+         use_in_memory: bool = True,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 1,
+     ) -> None:
+         """
+         Train the model on the provided data.
+
+         If a datamodule is provided, then training will be performed using it.
+         Alternatively, the training data can be provided as arrays or paths.
+
+         If `use_in_memory` is set to True, the source provided as Path or str will be
+         loaded in memory if it fits. Otherwise, training will be performed by loading
+         patches from the files one by one. Training on arrays is always performed
+         in memory.
+
+         If no validation source is provided, then the validation is extracted from
+         the training data using `val_percentage` and `val_minimum_split`. For data
+         provided as Path or str, the percentage and minimum number apply to the
+         number of files; for arrays, they apply to the number of patches.
+
+         Parameters
+         ----------
+         datamodule : Optional[CAREamicsTrainData], optional
+             Datamodule to train on, by default None.
+         train_source : Optional[Union[Path, str, np.ndarray]], optional
+             Train source, if no datamodule is provided, by default None.
+         val_source : Optional[Union[Path, str, np.ndarray]], optional
+             Validation source, if no datamodule is provided, by default None.
+         train_target : Optional[Union[Path, str, np.ndarray]], optional
+             Train target source, if no datamodule is provided, by default None.
+         val_target : Optional[Union[Path, str, np.ndarray]], optional
+             Validation target source, if no datamodule is provided, by default None.
+         use_in_memory : bool, optional
+             Use in-memory dataset if possible, by default True.
+         val_percentage : float, optional
+             Percentage of validation extracted from training data, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of validation patches or files extracted from the
+             training data, by default 1.
+
+         Raises
+         ------
+         ValueError
+             If both `datamodule` and `train_source` are provided.
+         ValueError
+             If sources are not of the same type (e.g. train is an array and val is
+             a Path).
+         ValueError
+             If the training target is provided to N2V.
+         ValueError
+             If neither a datamodule nor a source is provided.
+         """
+         if datamodule is not None and train_source is not None:
+             raise ValueError(
+                 "Only one of `datamodule` and `train_source` can be provided."
+             )
+
+         # check that inputs are the same type
+         source_types = {
+             type(s)
+             for s in (train_source, val_source, train_target, val_target)
+             if s is not None
+         }
+         if len(source_types) > 1:
+             raise ValueError("All sources should be of the same type.")
+
+         # train
+         if datamodule is not None:
+             self._train_on_datamodule(datamodule=datamodule)
+
+         else:
+             # raise error if target is provided to N2V
+             if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:
+                 if train_target is not None:
+                     raise ValueError(
+                         "Training target not compatible with N2V training."
+                     )
+
+             # dispatch the training
+             if isinstance(train_source, np.ndarray):
+                 # mypy checks
+                 assert isinstance(val_source, np.ndarray) or val_source is None
+                 assert isinstance(train_target, np.ndarray) or train_target is None
+                 assert isinstance(val_target, np.ndarray) or val_target is None
+
+                 self._train_on_array(
+                     train_source,
+                     val_source,
+                     train_target,
+                     val_target,
+                     val_percentage,
+                     val_minimum_split,
+                 )
+
+             elif isinstance(train_source, Path) or isinstance(train_source, str):
+                 # mypy checks
+                 assert (
+                     isinstance(val_source, Path)
+                     or isinstance(val_source, str)
+                     or val_source is None
+                 )
+                 assert (
+                     isinstance(train_target, Path)
+                     or isinstance(train_target, str)
+                     or train_target is None
+                 )
+                 assert (
+                     isinstance(val_target, Path)
+                     or isinstance(val_target, str)
+                     or val_target is None
+                 )
+
+                 self._train_on_path(
+                     train_source,
+                     val_source,
+                     train_target,
+                     val_target,
+                     use_in_memory,
+                     val_percentage,
+                     val_minimum_split,
+                 )
+
+             else:
+                 raise ValueError(
+                     f"Invalid input, expected a str, Path, array or "
+                     f"CAREamicsTrainData instance (got {type(train_source)})."
+                 )
+
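A minimal training sketch against the keyword-only signature above, reusing the `careamist` object from the earlier example; the array and split values are illustrative:

    import numpy as np

    # illustrative noisy 2D image; validation patches are split off automatically
    train_image = np.random.rand(256, 256).astype(np.float32)

    careamist.train(
        train_source=train_image,
        val_percentage=0.1,   # 10% of the patches go to validation...
        val_minimum_split=5,  # ...but never fewer than 5 patches
    )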
+     def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
+         """
+         Train the model on the provided datamodule.
+
+         Parameters
+         ----------
+         datamodule : CAREamicsTrainData
+             Datamodule to train on.
+         """
+         # record datamodule
+         self.train_datamodule = datamodule
+
+         self.trainer.fit(self.model, datamodule=datamodule)
+
+     def _train_on_array(
+         self,
+         train_data: np.ndarray,
+         val_data: Optional[np.ndarray] = None,
+         train_target: Optional[np.ndarray] = None,
+         val_target: Optional[np.ndarray] = None,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 5,
+     ) -> None:
+         """
+         Train the model on the provided data arrays.
+
+         Parameters
+         ----------
+         train_data : np.ndarray
+             Training data.
+         val_data : Optional[np.ndarray], optional
+             Validation data, by default None.
+         train_target : Optional[np.ndarray], optional
+             Train target data, by default None.
+         val_target : Optional[np.ndarray], optional
+             Validation target data, by default None.
+         val_percentage : float, optional
+             Percentage of patches to use for validation, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of patches to use for validation, by default 5.
+         """
+         # create datamodule
+         datamodule = CAREamicsTrainData(
+             data_config=self.cfg.data_config,
+             train_data=train_data,
+             val_data=val_data,
+             train_data_target=train_target,
+             val_data_target=val_target,
+             val_percentage=val_percentage,
+             val_minimum_split=val_minimum_split,
+         )
+
+         # train
+         self.train(datamodule=datamodule)
+
+     def _train_on_path(
+         self,
+         path_to_train_data: Union[Path, str],
+         path_to_val_data: Optional[Union[Path, str]] = None,
+         path_to_train_target: Optional[Union[Path, str]] = None,
+         path_to_val_target: Optional[Union[Path, str]] = None,
+         use_in_memory: bool = True,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 1,
+     ) -> None:
+         """
+         Train the model on the provided data paths.
+
+         Parameters
+         ----------
+         path_to_train_data : Union[Path, str]
+             Path to the training data.
+         path_to_val_data : Optional[Union[Path, str]], optional
+             Path to validation data, by default None.
+         path_to_train_target : Optional[Union[Path, str]], optional
+             Path to train target data, by default None.
+         path_to_val_target : Optional[Union[Path, str]], optional
+             Path to validation target data, by default None.
+         use_in_memory : bool, optional
+             Use in-memory dataset if possible, by default True.
+         val_percentage : float, optional
+             Percentage of files to use for validation, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of files to use for validation, by default 1.
+         """
+         # sanity check on data (path exists)
+         path_to_train_data = check_path_exists(path_to_train_data)
+
+         if path_to_val_data is not None:
+             path_to_val_data = check_path_exists(path_to_val_data)
+
+         if path_to_train_target is not None:
+             path_to_train_target = check_path_exists(path_to_train_target)
+
+         if path_to_val_target is not None:
+             path_to_val_target = check_path_exists(path_to_val_target)
+
+         # create datamodule
+         datamodule = CAREamicsTrainData(
+             data_config=self.cfg.data_config,
+             train_data=path_to_train_data,
+             val_data=path_to_val_data,
+             train_data_target=path_to_train_target,
+             val_data_target=path_to_val_target,
+             use_in_memory=use_in_memory,
+             val_percentage=val_percentage,
+             val_minimum_split=val_minimum_split,
+         )
+
+         # train
+         self.train(datamodule=datamodule)
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self,
+         source: CAREamicsPredictData,
+         *,
+         checkpoint: Optional[Literal["best", "last"]] = None,
+     ) -> Union[list, np.ndarray]:
+         ...
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self,
+         source: Union[Path, str],
+         *,
+         batch_size: int = 1,
+         tile_size: Optional[Tuple[int, ...]] = None,
+         tile_overlap: Tuple[int, ...] = (48, 48),
+         axes: Optional[str] = None,
+         data_type: Optional[Literal["tiff", "custom"]] = None,
+         transforms: Optional[List[TRANSFORMS_UNION]] = None,
+         tta_transforms: bool = True,
+         dataloader_params: Optional[Dict] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         checkpoint: Optional[Literal["best", "last"]] = None,
+     ) -> Union[list, np.ndarray]:
+         ...
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self,
+         source: np.ndarray,
+         *,
+         batch_size: int = 1,
+         tile_size: Optional[Tuple[int, ...]] = None,
+         tile_overlap: Tuple[int, ...] = (48, 48),
+         axes: Optional[str] = None,
+         data_type: Optional[Literal["array"]] = None,
+         transforms: Optional[List[TRANSFORMS_UNION]] = None,
+         tta_transforms: bool = True,
+         dataloader_params: Optional[Dict] = None,
+         checkpoint: Optional[Literal["best", "last"]] = None,
+     ) -> Union[list, np.ndarray]:
+         ...
+
+     def predict(
+         self,
+         source: Union[CAREamicsPredictData, Path, str, np.ndarray],
+         *,
+         batch_size: int = 1,
+         tile_size: Optional[Tuple[int, ...]] = None,
+         tile_overlap: Tuple[int, ...] = (48, 48),
+         axes: Optional[str] = None,
+         data_type: Optional[Literal["array", "tiff", "custom"]] = None,
+         transforms: Optional[List[TRANSFORMS_UNION]] = None,
+         tta_transforms: bool = True,
+         dataloader_params: Optional[Dict] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         checkpoint: Optional[Literal["best", "last"]] = None,
+         **kwargs: Any,
+     ) -> Union[List[np.ndarray], np.ndarray]:
+         """
+         Make predictions on the provided data.
+
+         Input can be a CAREamicsPredictData instance, a data file path, or a numpy array.
+
+         If `data_type`, `axes` and `tile_size` are not provided, the training
+         configuration parameters will be used, with the `patch_size` instead of
+         `tile_size`.
+
+         The default transforms are defined in the `InferenceModel` Pydantic model.
+
+         Test-time augmentation (TTA) can be switched off using the `tta_transforms`
+         parameter.
+
+         Parameters
+         ----------
+         source : Union[CAREamicsPredictData, Path, str, np.ndarray]
+             Data to predict on.
+         batch_size : int, optional
+             Batch size for prediction, by default 1.
+         tile_size : Optional[Tuple[int, ...]], optional
+             Size of the tiles to use for prediction, by default None.
+         tile_overlap : Tuple[int, ...], optional
+             Overlap between tiles, by default (48, 48).
+         axes : Optional[str], optional
+             Axes of the input data, by default None.
+         data_type : Optional[Literal["array", "tiff", "custom"]], optional
+             Type of the input data, by default None.
+         transforms : Optional[List[TRANSFORMS_UNION]], optional
+             List of transforms to apply to the data, by default None.
+         tta_transforms : bool, optional
+             Whether to apply test-time augmentation, by default True.
+         dataloader_params : Optional[Dict], optional
+             Parameters to pass to the dataloader, by default None.
+         read_source_func : Optional[Callable], optional
+             Function to read the source data, by default None.
+         extension_filter : str, optional
+             Filter for the file extension, by default "".
+         checkpoint : Optional[Literal["best", "last"]], optional
+             Checkpoint to use for prediction, by default None.
+         **kwargs : Any
+             Unused.
+
+         Returns
+         -------
+         Union[List[np.ndarray], np.ndarray]
+             Predictions made by the model.
+
+         Raises
+         ------
+         ValueError
+             If the input is not a CAREamicsPredictData instance, a path or a numpy array.
+         """
+         if isinstance(source, CAREamicsPredictData):
+             # record datamodule
+             self.pred_datamodule = source
+
+             return self.trainer.predict(
+                 model=self.model, datamodule=source, ckpt_path=checkpoint
+             )
+         else:
+             if self.cfg is None:
+                 raise ValueError(
+                     "No configuration found. Train a model or load from a "
+                     "checkpoint before predicting."
+                 )
+             # create predict config, reuse training config if parameters missing
+             prediction_config = create_inference_configuration(
+                 training_configuration=self.cfg,
+                 tile_size=tile_size,
+                 tile_overlap=tile_overlap,
+                 data_type=data_type,
+                 axes=axes,
+                 transforms=transforms,
+                 tta_transforms=tta_transforms,
+                 batch_size=batch_size,
+             )
+
+             # remove batch from dataloader parameters (priority given to config)
+             if dataloader_params is None:
+                 dataloader_params = {}
+             if "batch_size" in dataloader_params:
+                 del dataloader_params["batch_size"]
+
+             if isinstance(source, Path) or isinstance(source, str):
+                 # check the source
+                 source_path = check_path_exists(source)
+
+                 # create datamodule
+                 datamodule = CAREamicsPredictData(
+                     pred_config=prediction_config,
+                     pred_data=source_path,
+                     read_source_func=read_source_func,
+                     extension_filter=extension_filter,
+                     dataloader_params=dataloader_params,
+                 )
+
+                 # record datamodule
+                 self.pred_datamodule = datamodule
+
+                 return self.trainer.predict(
+                     model=self.model, datamodule=datamodule, ckpt_path=checkpoint
+                 )
+
+             elif isinstance(source, np.ndarray):
+                 # create datamodule
+                 datamodule = CAREamicsPredictData(
+                     pred_config=prediction_config,
+                     pred_data=source,
+                     dataloader_params=dataloader_params,
+                 )
+
+                 # record datamodule
+                 self.pred_datamodule = datamodule
+
+                 return self.trainer.predict(
+                     model=self.model, datamodule=datamodule, ckpt_path=checkpoint
+                 )
+
+             else:
+                 raise ValueError(
+                     f"Invalid input. Expected a CAREamicsPredictData instance, "
+                     f"a path or a np.ndarray (got {type(source)})."
+                 )
+
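A tiled-prediction sketch matching the array overload above; tile sizes are illustrative, and parameters left unset fall back to the training configuration:

    # 2D tiled prediction on the array used for training, with TTA disabled
    prediction = careamist.predict(
        source=train_image,
        tile_size=(128, 128),
        tile_overlap=(48, 48),
        tta_transforms=False,
    )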
+     def export_to_bmz(
+         self,
+         path: Union[Path, str],
+         name: str,
+         authors: List[dict],
+         input_array: Optional[np.ndarray] = None,
+         general_description: str = "",
+         channel_names: Optional[List[str]] = None,
+         data_description: Optional[str] = None,
+     ) -> None:
+         """Export the model to the BioImage Model Zoo format.
+
+         The input array must be of shape SC(Z)YX, with S and C singleton dimensions.
+
+         Parameters
+         ----------
+         path : Union[Path, str]
+             Path to save the model.
+         name : str
+             Name of the model.
+         authors : List[dict]
+             List of authors of the model.
+         input_array : Optional[np.ndarray], optional
+             Input array for the model, must be of shape SC(Z)YX, by default None.
+         general_description : str
+             General description of the model, used in the metadata of the BMZ archive.
+         channel_names : Optional[List[str]], optional
+             Channel names, by default None.
+         data_description : Optional[str], optional
+             Description of the data, by default None.
+         """
+         if input_array is None:
+             # generate images, priority is given to the prediction data module
+             if self.pred_datamodule is not None:
+                 # unpack a batch, ignore masks or targets
+                 input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))
+
+                 # convert torch.Tensor to numpy
+                 input_patch = input_patch.numpy()
+             elif self.train_datamodule is not None:
+                 input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
+                 input_patch = input_patch.numpy()
+             else:
+                 if (
+                     self.cfg.data_config.mean is None
+                     or self.cfg.data_config.std is None
+                 ):
+                     raise ValueError(
+                         "Mean and std cannot be None in the configuration in order "
+                         "to export to the BMZ format. Was the model trained?"
+                     )
+
+                 # create a random input array
+                 input_patch = np.random.normal(
+                     loc=self.cfg.data_config.mean,
+                     scale=self.cfg.data_config.std,
+                     size=self.cfg.data_config.patch_size,
+                 ).astype(np.float32)[
+                     np.newaxis, np.newaxis, ...
+                 ]  # add S & C dimensions
+         else:
+             input_patch = input_array
+
+         # if there is a batch dimension
+         if input_patch.shape[0] > 1:
+             input_patch = input_patch[0:1, ...]  # keep singleton dim
+
+         # axes need to be reformatted for the export because reshaping was done in
+         # the datamodule
+         if "Z" in self.cfg.data_config.axes:
+             axes = "SCZYX"
+         else:
+             axes = "SCYX"
+
+         # predict output, remove extra dimensions for the purpose of the prediction
+         output_patch = self.predict(
+             input_patch,
+             data_type=SupportedData.ARRAY.value,
+             axes=axes,
+             tta_transforms=False,
+         )
+
+         if not isinstance(output_patch, np.ndarray):
+             raise ValueError(
+                 f"Numpy array required for export to BioImage Model Zoo, got "
+                 f"{type(output_patch)}."
+             )
+
+         export_to_bmz(
+             model=self.model,
+             config=self.cfg,
+             path=path,
+             name=name,
+             general_description=general_description,
+             authors=authors,
+             input_array=input_patch,
+             output_array=output_patch,
+             channel_names=channel_names,
+             data_description=data_description,
+         )
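Finally, a BMZ export sketch based on the signature above; the output path and author entry are placeholders (the author dict following a "name" key convention is an assumption here):

    # export the trained model; an example input is drawn from the datamodules if omitted
    careamist.export_to_bmz(
        path="exports/n2v_model.zip",    # hypothetical output path
        name="N2V-example",
        authors=[{"name": "Jane Doe"}],  # assumed author dict format
        general_description="Noise2Void model trained with CAREamics.",
    )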