careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/careamist.py ADDED
@@ -0,0 +1,726 @@
+"""A class to train, predict and export models in CAREamics."""
+
+from pathlib import Path
+from typing import Any, Callable, Literal, Optional, Union, overload
+
+import numpy as np
+from numpy.typing import NDArray
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import (
+    Callback,
+    EarlyStopping,
+    ModelCheckpoint,
+)
+from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+
+from careamics.config import (
+    Configuration,
+    load_configuration,
+)
+from careamics.config.support import (
+    SupportedAlgorithm,
+    SupportedArchitecture,
+    SupportedData,
+    SupportedLogger,
+)
+from careamics.dataset.dataset_utils import reshape_array
+from careamics.lightning import (
+    CAREamicsModule,
+    HyperParametersCallback,
+    PredictDataModule,
+    ProgressBarCallback,
+    TrainDataModule,
+    create_predict_datamodule,
+)
+from careamics.model_io import export_to_bmz, load_pretrained
+from careamics.prediction_utils import convert_outputs
+from careamics.utils import check_path_exists, get_logger
+
+logger = get_logger(__name__)
+
+LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
+
+
+class CAREamist:
+    """Main CAREamics class, allowing training and prediction using various algorithms.
+
+    Parameters
+    ----------
+    source : pathlib.Path or str or CAREamics Configuration
+        Path to a configuration file or a trained model.
+    work_dir : str, optional
+        Path to working directory in which to save checkpoints and logs,
+        by default None.
+    experiment_name : str, optional
+        Experiment name used for checkpoints, by default "CAREamics".
+    callbacks : list of Callback, optional
+        List of callbacks to use during training and prediction, by default None.
+
+    Attributes
+    ----------
+    model : CAREamicsModule
+        CAREamics model.
+    cfg : Configuration
+        CAREamics configuration.
+    trainer : Trainer
+        PyTorch Lightning trainer.
+    experiment_logger : TensorBoardLogger or WandbLogger
+        Experiment logger, "wandb" or "tensorboard".
+    work_dir : pathlib.Path
+        Working directory.
+    train_datamodule : TrainDataModule
+        Training datamodule.
+    pred_datamodule : PredictDataModule
+        Prediction datamodule.
+    """
+
+    @overload
+    def __init__(  # numpydoc ignore=GL08
+        self,
+        source: Union[Path, str],
+        work_dir: Optional[str] = None,
+        experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(  # numpydoc ignore=GL08
+        self,
+        source: Configuration,
+        work_dir: Optional[str] = None,
+        experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
+    ) -> None: ...
+
+    def __init__(
+        self,
+        source: Union[Path, str, Configuration],
+        work_dir: Optional[Union[Path, str]] = None,
+        experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
+    ) -> None:
+        """
+        Initialize CAREamist with a configuration object or a path.
+
+        A configuration object can be created directly by calling `Configuration`, by
+        using the configuration factory, or by loading a configuration from a yaml file.
+
+        Path can contain either a yaml file with parameters, or a saved checkpoint.
+
+        If no working directory is provided, the current working directory is used.
+
+        If `source` is a checkpoint, then `experiment_name` is used to name the
+        checkpoint, and is recorded in the configuration.
+
+        Parameters
+        ----------
+        source : pathlib.Path or str or CAREamics Configuration
+            Path to a configuration file or a trained model.
+        work_dir : str, optional
+            Path to working directory in which to save checkpoints and logs,
+            by default None.
+        experiment_name : str, optional
+            Experiment name used for checkpoints, by default "CAREamics".
+        callbacks : list of Callback, optional
+            List of callbacks to use during training and prediction, by default None.
+
+        Raises
+        ------
+        NotImplementedError
+            If the model is loaded from BioImage Model Zoo.
+        ValueError
+            If no hyperparameters are found in the checkpoint.
+        ValueError
+            If no data module hyperparameters are found in the checkpoint.
+        """
+        # select current working directory if work_dir is None
+        if work_dir is None:
+            self.work_dir = Path.cwd()
+            logger.warning(
+                f"No working directory provided. Using current working directory: "
+                f"{self.work_dir}."
+            )
+        else:
+            self.work_dir = Path(work_dir)
+
+        # configuration object
+        if isinstance(source, Configuration):
+            self.cfg = source
+
+            # instantiate model
+            self.model = CAREamicsModule(
+                algorithm_config=self.cfg.algorithm_config,
+            )
+
+        # path to configuration file or model
+        else:
+            source = check_path_exists(source)
+
+            # configuration file
+            if source.is_file() and (
+                source.suffix == ".yaml" or source.suffix == ".yml"
+            ):
+                # load configuration
+                self.cfg = load_configuration(source)
+
+                # instantiate model
+                self.model = CAREamicsModule(
+                    algorithm_config=self.cfg.algorithm_config,
+                )
+
+            # attempt loading a pre-trained model
+            else:
+                self.model, self.cfg = load_pretrained(source)
+
+        # define the checkpoint saving callback
+        self._define_callbacks(callbacks)
+
+        # instantiate logger
+        if self.cfg.training_config.has_logger():
+            if self.cfg.training_config.logger == SupportedLogger.WANDB:
+                self.experiment_logger: LOGGER_TYPES = WandbLogger(
+                    name=self.cfg.experiment_name,
+                    save_dir=self.work_dir / Path("logs"),
+                )
+            elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
+                self.experiment_logger = TensorBoardLogger(
+                    save_dir=self.work_dir / Path("logs"),
+                )
+        else:
+            self.experiment_logger = None
+
+        # instantiate trainer
+        self.trainer = Trainer(
+            max_epochs=self.cfg.training_config.num_epochs,
+            callbacks=self.callbacks,
+            default_root_dir=self.work_dir,
+            logger=self.experiment_logger,
+        )
+
+        # placeholder for the datamodules
+        self.train_datamodule: Optional[TrainDataModule] = None
+        self.pred_datamodule: Optional[PredictDataModule] = None
+
+    def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
+        """Define the callbacks for the training loop.
+
+        Parameters
+        ----------
+        callbacks : list of Callback, optional
+            List of callbacks to use during training and prediction, by default None.
+        """
+        self.callbacks = [] if callbacks is None else callbacks
+
+        # check that user callbacks are not any of the CAREamics callbacks
+        for c in self.callbacks:
+            if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
+                raise ValueError(
+                    "ModelCheckpoint and EarlyStopping callbacks are already defined "
+                    "in CAREamics and should only be modified through the "
+                    "training configuration (see TrainingConfig)."
+                )
+
+            if isinstance(c, HyperParametersCallback) or isinstance(
+                c, ProgressBarCallback
+            ):
+                raise ValueError(
+                    "HyperParameter and ProgressBar callbacks are defined internally "
+                    "and should not be passed as callbacks."
+                )
+
+        # checkpoint callback saves checkpoints during training
+        self.callbacks.extend(
+            [
+                HyperParametersCallback(self.cfg),
+                ModelCheckpoint(
+                    dirpath=self.work_dir / Path("checkpoints"),
+                    filename=self.cfg.experiment_name,
+                    **self.cfg.training_config.checkpoint_callback.model_dump(),
+                ),
+                ProgressBarCallback(),
+            ]
+        )
+
+        # early stopping callback
+        if self.cfg.training_config.early_stopping_callback is not None:
+            self.callbacks.append(
+                EarlyStopping(self.cfg.training_config.early_stopping_callback)
+            )
+
+    # TODO: is there a more elegant way than calling train again after _train_on_path?
+    def train(
+        self,
+        *,
+        datamodule: Optional[TrainDataModule] = None,
+        train_source: Optional[Union[Path, str, NDArray]] = None,
+        val_source: Optional[Union[Path, str, NDArray]] = None,
+        train_target: Optional[Union[Path, str, NDArray]] = None,
+        val_target: Optional[Union[Path, str, NDArray]] = None,
+        use_in_memory: bool = True,
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 1,
+    ) -> None:
+        """
+        Train the model on the provided data.
+
+        If a datamodule is provided, then training will be performed using it.
+        Alternatively, the training data can be provided as arrays or paths.
+
+        If `use_in_memory` is set to True, the source provided as Path or str will be
+        loaded in memory if it fits. Otherwise, training will be performed by loading
+        patches from the files one by one. Training on arrays is always performed
+        in memory.
+
+        If no validation source is provided, then the validation is extracted from
+        the training data using `val_percentage` and `val_minimum_split`. In the case
+        of data provided as Path or str, the percentage and minimum number are applied
+        to the number of files. For arrays, it is the number of patches.
+
+        Parameters
+        ----------
+        datamodule : TrainDataModule, optional
+            Datamodule to train on, by default None.
+        train_source : pathlib.Path or str or NDArray, optional
+            Train source, if no datamodule is provided, by default None.
+        val_source : pathlib.Path or str or NDArray, optional
+            Validation source, if no datamodule is provided, by default None.
+        train_target : pathlib.Path or str or NDArray, optional
+            Train target source, if no datamodule is provided, by default None.
+        val_target : pathlib.Path or str or NDArray, optional
+            Validation target source, if no datamodule is provided, by default None.
+        use_in_memory : bool, optional
+            Use in memory dataset if possible, by default True.
+        val_percentage : float, optional
+            Percentage of validation extracted from training data, by default 0.1.
+        val_minimum_split : int, optional
+            Minimum number of validation (patch or file) extracted from training data,
+            by default 1.
+
+        Raises
+        ------
+        ValueError
+            If both `datamodule` and `train_source` are provided.
+        ValueError
+            If sources are not of the same type (e.g. train is an array and val is
+            a Path).
+        ValueError
+            If the training target is provided to N2V.
+        ValueError
+            If neither a datamodule nor a source is provided.
+        """
+        if datamodule is not None and train_source is not None:
+            raise ValueError(
+                "Only one of `datamodule` and `train_source` can be provided."
+            )
+
+        # check that inputs are the same type
+        source_types = {
+            type(s)
+            for s in (train_source, val_source, train_target, val_target)
+            if s is not None
+        }
+        if len(source_types) > 1:
+            raise ValueError("All sources should be of the same type.")
+
+        # train
+        if datamodule is not None:
+            self._train_on_datamodule(datamodule=datamodule)
+
+        else:
+            # raise error if target is provided to N2V
+            if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:
+                if train_target is not None:
+                    raise ValueError(
+                        "Training target not compatible with N2V training."
+                    )
+
+            # dispatch the training
+            if isinstance(train_source, np.ndarray):
+                # mypy checks
+                assert isinstance(val_source, np.ndarray) or val_source is None
+                assert isinstance(train_target, np.ndarray) or train_target is None
+                assert isinstance(val_target, np.ndarray) or val_target is None
+
+                self._train_on_array(
+                    train_source,
+                    val_source,
+                    train_target,
+                    val_target,
+                    val_percentage,
+                    val_minimum_split,
+                )
+
+            elif isinstance(train_source, Path) or isinstance(train_source, str):
+                # mypy checks
+                assert (
+                    isinstance(val_source, Path)
+                    or isinstance(val_source, str)
+                    or val_source is None
+                )
+                assert (
+                    isinstance(train_target, Path)
+                    or isinstance(train_target, str)
+                    or train_target is None
+                )
+                assert (
+                    isinstance(val_target, Path)
+                    or isinstance(val_target, str)
+                    or val_target is None
+                )
+
+                self._train_on_path(
+                    train_source,
+                    val_source,
+                    train_target,
+                    val_target,
+                    use_in_memory,
+                    val_percentage,
+                    val_minimum_split,
+                )
+
+            else:
+                raise ValueError(
+                    f"Invalid input, expected a str, Path, array or TrainDataModule "
+                    f"instance (got {type(train_source)})."
+                )
+
+    def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
+        """
+        Train the model on the provided datamodule.
+
+        Parameters
+        ----------
+        datamodule : TrainDataModule
+            Datamodule to train on.
+        """
+        # record datamodule
+        self.train_datamodule = datamodule
+
+        self.trainer.fit(self.model, datamodule=datamodule)
+
+    def _train_on_array(
+        self,
+        train_data: NDArray,
+        val_data: Optional[NDArray] = None,
+        train_target: Optional[NDArray] = None,
+        val_target: Optional[NDArray] = None,
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 5,
+    ) -> None:
+        """
+        Train the model on the provided data arrays.
+
+        Parameters
+        ----------
+        train_data : NDArray
+            Training data.
+        val_data : NDArray, optional
+            Validation data, by default None.
+        train_target : NDArray, optional
+            Train target data, by default None.
+        val_target : NDArray, optional
+            Validation target data, by default None.
+        val_percentage : float, optional
+            Percentage of patches to use for validation, by default 0.1.
+        val_minimum_split : int, optional
+            Minimum number of patches to use for validation, by default 5.
+        """
+        # create datamodule
+        datamodule = TrainDataModule(
+            data_config=self.cfg.data_config,
+            train_data=train_data,
+            val_data=val_data,
+            train_data_target=train_target,
+            val_data_target=val_target,
+            val_percentage=val_percentage,
+            val_minimum_split=val_minimum_split,
+        )
+
+        # train
+        self.train(datamodule=datamodule)
+
+    def _train_on_path(
+        self,
+        path_to_train_data: Union[Path, str],
+        path_to_val_data: Optional[Union[Path, str]] = None,
+        path_to_train_target: Optional[Union[Path, str]] = None,
+        path_to_val_target: Optional[Union[Path, str]] = None,
+        use_in_memory: bool = True,
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 1,
+    ) -> None:
+        """
+        Train the model on the provided data paths.
+
+        Parameters
+        ----------
+        path_to_train_data : pathlib.Path or str
+            Path to the training data.
+        path_to_val_data : pathlib.Path or str, optional
+            Path to validation data, by default None.
+        path_to_train_target : pathlib.Path or str, optional
+            Path to train target data, by default None.
+        path_to_val_target : pathlib.Path or str, optional
+            Path to validation target data, by default None.
+        use_in_memory : bool, optional
+            Use in memory dataset if possible, by default True.
+        val_percentage : float, optional
+            Percentage of files to use for validation, by default 0.1.
+        val_minimum_split : int, optional
+            Minimum number of files to use for validation, by default 1.
+        """
+        # sanity check on data (path exists)
+        path_to_train_data = check_path_exists(path_to_train_data)
+
+        if path_to_val_data is not None:
+            path_to_val_data = check_path_exists(path_to_val_data)
+
+        if path_to_train_target is not None:
+            path_to_train_target = check_path_exists(path_to_train_target)
+
+        if path_to_val_target is not None:
+            path_to_val_target = check_path_exists(path_to_val_target)
+
+        # create datamodule
+        datamodule = TrainDataModule(
+            data_config=self.cfg.data_config,
+            train_data=path_to_train_data,
+            val_data=path_to_val_data,
+            train_data_target=path_to_train_target,
+            val_data_target=path_to_val_target,
+            use_in_memory=use_in_memory,
+            val_percentage=val_percentage,
+            val_minimum_split=val_minimum_split,
+        )
+
+        # train
+        self.train(datamodule=datamodule)
+
+    @overload
+    def predict(  # numpydoc ignore=GL08
+        self, source: PredictDataModule
+    ) -> Union[list[NDArray], NDArray]: ...
+
+    @overload
+    def predict(  # numpydoc ignore=GL08
+        self,
+        source: Union[Path, str],
+        *,
+        batch_size: int = 1,
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
+        axes: Optional[str] = None,
+        data_type: Optional[Literal["tiff", "custom"]] = None,
+        tta_transforms: bool = True,
+        dataloader_params: Optional[dict] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+    ) -> Union[list[NDArray], NDArray]: ...
+
+    @overload
+    def predict(  # numpydoc ignore=GL08
+        self,
+        source: NDArray,
+        *,
+        batch_size: int = 1,
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
+        axes: Optional[str] = None,
+        data_type: Optional[Literal["array"]] = None,
+        tta_transforms: bool = True,
+        dataloader_params: Optional[dict] = None,
+    ) -> Union[list[NDArray], NDArray]: ...
+
+    def predict(
+        self,
+        source: Union[PredictDataModule, Path, str, NDArray],
+        *,
+        batch_size: Optional[int] = None,
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
+        axes: Optional[str] = None,
+        data_type: Optional[Literal["array", "tiff", "custom"]] = None,
+        tta_transforms: bool = True,
+        dataloader_params: Optional[dict] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        **kwargs: Any,
+    ) -> Union[list[NDArray], NDArray]:
+        """
+        Make predictions on the provided data.
+
+        Input can be a PredictDataModule instance, a path to a data file, or a numpy
+        array.
+
+        If `data_type`, `axes` and `tile_size` are not provided, the training
+        configuration parameters will be used, with the `patch_size` instead of
+        `tile_size`.
+
+        Test-time augmentation (TTA) can be switched off using the `tta_transforms`
+        parameter. TTA applies all possible flips and 90 degree rotations to the
+        prediction input and averages the predictions. TTA should not be used if
+        you did not train with these augmentations.
+
+        Note that if you are using a UNet model and tiling, the tile size must be
+        divisible in every dimension by 2**d, where d is the depth of the model. This
+        avoids artefacts arising from the broken shift invariance induced by the
+        pooling layers of the UNet. If your image is smaller than the tile size along
+        one axis (as can happen in Z), consider padding it.
+
+        Parameters
+        ----------
+        source : PredictDataModule, pathlib.Path, str or numpy.ndarray
+            Data to predict on.
+        batch_size : int, default=1
+            Batch size for prediction.
+        tile_size : tuple of int, optional
+            Size of the tiles to use for prediction.
+        tile_overlap : tuple of int, default=(48, 48)
+            Overlap between tiles, can be None.
+        axes : str, optional
+            Axes of the input data, by default None.
+        data_type : {"array", "tiff", "custom"}, optional
+            Type of the input data.
+        tta_transforms : bool, default=True
+            Whether to apply test-time augmentation.
+        dataloader_params : dict, optional
+            Parameters to pass to the dataloader.
+        read_source_func : Callable, optional
+            Function to read the source data.
+        extension_filter : str, default=""
+            Filter for the file extension.
+        **kwargs : Any
+            Unused.
+
+        Returns
+        -------
+        list of NDArray or NDArray
+            Predictions made by the model.
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the configuration.
+        ValueError
+            If tile size is not divisible by 2**depth for UNet models.
+        ValueError
+            If tile overlap is not specified.
+        """
+        if (
+            self.cfg.data_config.image_means is None
+            or self.cfg.data_config.image_stds is None
+        ):
+            raise ValueError("Mean and std must be provided in the configuration.")
+
+        # tile size for UNets
+        if tile_size is not None:
+            model = self.cfg.algorithm_config.model
+
+            if model.architecture == SupportedArchitecture.UNET.value:
+                # tile size must be equal to k*2^n, where n is the number of pooling
+                # layers (equal to the depth) and k is an integer
+                depth = model.depth
+                tile_increment = 2**depth
+
+                for i, t in enumerate(tile_size):
+                    if t % tile_increment != 0:
+                        raise ValueError(
+                            f"Tile size must be divisible by {tile_increment} along "
+                            f"all axes (got {t} for axis {i}). If your image size is "
+                            f"smaller along one axis (e.g. Z), consider padding the "
+                            f"image."
+                        )
+
+            # tile overlaps must be specified
+            if tile_overlap is None:
+                raise ValueError("Tile overlap must be specified.")
+
+        # create the prediction
+        self.pred_datamodule = create_predict_datamodule(
+            pred_data=source,
+            data_type=data_type or self.cfg.data_config.data_type,
+            axes=axes or self.cfg.data_config.axes,
+            image_means=self.cfg.data_config.image_means,
+            image_stds=self.cfg.data_config.image_stds,
+            tile_size=tile_size,
+            tile_overlap=tile_overlap,
+            batch_size=batch_size or self.cfg.data_config.batch_size,
+            tta_transforms=tta_transforms,
+            read_source_func=read_source_func,
+            extension_filter=extension_filter,
+            dataloader_params=dataloader_params,
+        )
+
+        # predict
+        predictions = self.trainer.predict(
+            model=self.model, datamodule=self.pred_datamodule
+        )
+        return convert_outputs(predictions, self.pred_datamodule.tiled)
+
+    def export_to_bmz(
+        self,
+        path_to_archive: Union[Path, str],
+        friendly_model_name: str,
+        input_array: NDArray,
+        authors: list[dict],
+        general_description: str = "",
+        channel_names: Optional[list[str]] = None,
+        data_description: Optional[str] = None,
+    ) -> None:
+        """Export the model to the BioImage Model Zoo format.
+
+        This method packages the current weights into a zip file that can be uploaded
+        to the BioImage Model Zoo. The archive consists of the model weights, the model
+        specifications and various files (inputs, outputs, README, env.yaml etc.).
+
+        `path_to_archive` should point to a file with a ".zip" extension.
+
+        `friendly_model_name` is the name used for the model in the BMZ specs
+        and website, it should consist of letters, numbers, dashes, underscores and
+        parentheses only.
+
+        Input array must be of the same dimensions as the axes recorded in the
+        configuration of the `CAREamist`.
+
+        Parameters
+        ----------
+        path_to_archive : pathlib.Path or str
+            Path in which to save the model, including file name, which should end with
+            ".zip".
+        friendly_model_name : str
+            Name of the model as used in the BMZ specs, it should consist of letters,
+            numbers, dashes, underscores and parentheses only.
+        input_array : NDArray
+            Input array used to validate the model and as example.
+        authors : list of dict
+            List of authors of the model.
+        general_description : str
+            General description of the model, used in the metadata of the BMZ archive.
+        channel_names : list of str, optional
+            Channel names, by default None.
+        data_description : str, optional
+            Description of the data, by default None.
+        """
+        # TODO: add in docs that it is expected that input_array dimensions match
+        # those in data_config
+
+        output_patch = self.predict(
+            input_array,
+            data_type=SupportedData.ARRAY.value,
+            tta_transforms=False,
+        )
+        output = np.concatenate(output_patch, axis=0)
+        input_array = reshape_array(input_array, self.cfg.data_config.axes)
+
+        export_to_bmz(
+            model=self.model,
+            config=self.cfg,
+            path_to_archive=path_to_archive,
+            model_name=friendly_model_name,
+            general_description=general_description,
+            authors=authors,
+            input_array=input_array,
+            output_array=output,
+            channel_names=channel_names,
+            data_description=data_description,
+        )
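
For orientation, here is a minimal usage sketch of the CAREamist API added in this release. It is illustrative only and not part of the diff: "config.yml", the random array, the tile sizes and the author entry are hypothetical placeholders, and it assumes `CAREamist` is re-exported from the top-level `careamics` package (see careamics/__init__.py above).

# Usage sketch (illustrative): train on an array, run tiled prediction,
# then package the model for the BioImage Model Zoo.
import numpy as np

from careamics import CAREamist  # assumes the top-level re-export

# source can be a Configuration object, a yaml configuration file or a checkpoint
careamist = CAREamist(source="config.yml", work_dir="runs")

# train on an in-memory array; without val_source, validation is split off
# from the training patches using val_percentage / val_minimum_split
train_image = np.random.rand(256, 256).astype(np.float32)
careamist.train(train_source=train_image)

# tiled prediction; for UNet models the tile size must be divisible by
# 2**depth along every axis (see CAREamist.predict above)
prediction = careamist.predict(
    source=train_image,
    tile_size=(64, 64),
    tile_overlap=(16, 16),
)

# export the trained model as a BMZ zip archive
careamist.export_to_bmz(
    path_to_archive="model.zip",
    friendly_model_name="example_model",
    input_array=train_image,
    authors=[{"name": "Jane Doe"}],
    general_description="Example export.",
)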