careamics-0.0.1-py3-none-any.whl → careamics-0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/careamist.py ADDED
@@ -0,0 +1,729 @@
+ """A class to train, predict and export models in CAREamics."""
+
+ from pathlib import Path
+ from typing import Any, Callable, Literal, Optional, Union, overload
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import (
+     Callback,
+     EarlyStopping,
+     ModelCheckpoint,
+ )
+ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+
+ from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
+ from careamics.config.support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedData,
+     SupportedLogger,
+ )
+ from careamics.dataset.dataset_utils import reshape_array
+ from careamics.lightning import (
+     FCNModule,
+     HyperParametersCallback,
+     PredictDataModule,
+     ProgressBarCallback,
+     TrainDataModule,
+     create_predict_datamodule,
+ )
+ from careamics.model_io import export_to_bmz, load_pretrained
+ from careamics.prediction_utils import convert_outputs
+ from careamics.utils import check_path_exists, get_logger
+
+ logger = get_logger(__name__)
+
+ LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
+
+
+ class CAREamist:
+     """Main CAREamics class, allowing training and prediction using various algorithms.
+
+     Parameters
+     ----------
+     source : pathlib.Path or str or CAREamics Configuration
+         Path to a configuration file or a trained model.
+     work_dir : str, optional
+         Path to working directory in which to save checkpoints and logs,
+         by default None.
+     experiment_name : str, optional
+         Experiment name used for checkpoints, by default "CAREamics".
+     callbacks : list of Callback, optional
+         List of callbacks to use during training and prediction, by default None.
+
+     Attributes
+     ----------
+     model : CAREamicsModule
+         CAREamics model.
+     cfg : Configuration
+         CAREamics configuration.
+     trainer : Trainer
+         PyTorch Lightning trainer.
+     experiment_logger : TensorBoardLogger or WandbLogger
+         Experiment logger, "wandb" or "tensorboard".
+     work_dir : pathlib.Path
+         Working directory.
+     train_datamodule : TrainDataModule
+         Training datamodule.
+     pred_datamodule : PredictDataModule
+         Prediction datamodule.
+     """
+
+     @overload
+     def __init__(  # numpydoc ignore=GL08
+         self,
+         source: Union[Path, str],
+         work_dir: Optional[str] = None,
+         experiment_name: str = "CAREamics",
+         callbacks: Optional[list[Callback]] = None,
+     ) -> None: ...
+
+     @overload
+     def __init__(  # numpydoc ignore=GL08
+         self,
+         source: Configuration,
+         work_dir: Optional[str] = None,
+         experiment_name: str = "CAREamics",
+         callbacks: Optional[list[Callback]] = None,
+     ) -> None: ...
+
+     def __init__(
+         self,
+         source: Union[Path, str, Configuration],
+         work_dir: Optional[Union[Path, str]] = None,
+         experiment_name: str = "CAREamics",
+         callbacks: Optional[list[Callback]] = None,
+     ) -> None:
+         """
+         Initialize CAREamist with a configuration object or a path.
+
+         A configuration object can be created directly by calling `Configuration`,
+         using the configuration factory, or by loading a configuration from a yaml file.
+
+         A path can point either to a yaml file with parameters or to a saved checkpoint.
+
+         If no working directory is provided, the current working directory is used.
+
+         If `source` is a checkpoint, then `experiment_name` is used to name the
+         checkpoint, and is recorded in the configuration.
+
+         Parameters
+         ----------
+         source : pathlib.Path or str or CAREamics Configuration
+             Path to a configuration file or a trained model.
+         work_dir : str, optional
+             Path to working directory in which to save checkpoints and logs,
+             by default None.
+         experiment_name : str, optional
+             Experiment name used for checkpoints, by default "CAREamics".
+         callbacks : list of Callback, optional
+             List of callbacks to use during training and prediction, by default None.
+
+         Raises
+         ------
+         NotImplementedError
+             If the model is loaded from BioImage Model Zoo.
+         ValueError
+             If no hyperparameters are found in the checkpoint.
+         ValueError
+             If no data module hyperparameters are found in the checkpoint.
+         """
+         # select current working directory if work_dir is None
+         if work_dir is None:
+             self.work_dir = Path.cwd()
+             logger.warning(
+                 f"No working directory provided. Using current working directory: "
+                 f"{self.work_dir}."
+             )
+         else:
+             self.work_dir = Path(work_dir)
+
+         # configuration object
+         if isinstance(source, Configuration):
+             self.cfg = source
+
+             # instantiate model
+             if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                 self.model = FCNModule(
+                     algorithm_config=self.cfg.algorithm_config,
+                 )
+             else:
+                 raise NotImplementedError("Architecture not supported.")
+
+         # path to configuration file or model
+         else:
+             source = check_path_exists(source)
+
+             # configuration file
+             if source.is_file() and (
+                 source.suffix == ".yaml" or source.suffix == ".yml"
+             ):
+                 # load configuration
+                 self.cfg = load_configuration(source)
+
+                 # instantiate model
+                 if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                     self.model = FCNModule(
+                         algorithm_config=self.cfg.algorithm_config,
+                     )  # type: ignore
+                 else:
+                     raise NotImplementedError("Architecture not supported.")
+
+             # attempt loading a pre-trained model
+             else:
+                 self.model, self.cfg = load_pretrained(source)
+
+         # define the checkpoint saving callback
+         self._define_callbacks(callbacks)
+
+         # instantiate logger
+         if self.cfg.training_config.has_logger():
+             if self.cfg.training_config.logger == SupportedLogger.WANDB:
+                 self.experiment_logger: LOGGER_TYPES = WandbLogger(
+                     name=self.cfg.experiment_name,
+                     save_dir=self.work_dir / Path("logs"),
+                 )
+             elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
+                 self.experiment_logger = TensorBoardLogger(
+                     save_dir=self.work_dir / Path("logs"),
+                 )
+         else:
+             self.experiment_logger = None
+
+         # instantiate trainer
+         self.trainer = Trainer(
+             max_epochs=self.cfg.training_config.num_epochs,
+             callbacks=self.callbacks,
+             default_root_dir=self.work_dir,
+             logger=self.experiment_logger,
+         )
+
+         # placeholder for the datamodules
+         self.train_datamodule: Optional[TrainDataModule] = None
+         self.pred_datamodule: Optional[PredictDataModule] = None
+
+     def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
+         """Define the callbacks for the training loop.
+
+         Parameters
+         ----------
+         callbacks : list of Callback, optional
+             List of callbacks to use during training and prediction, by default None.
+         """
+         self.callbacks = [] if callbacks is None else callbacks
+
+         # check that user callbacks are not any of the CAREamics callbacks
+         for c in self.callbacks:
+             if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
+                 raise ValueError(
+                     "ModelCheckpoint and EarlyStopping callbacks are already defined "
+                     "in CAREamics and should only be modified through the "
+                     "training configuration (see TrainingConfig)."
+                 )
+
+             if isinstance(c, HyperParametersCallback) or isinstance(
+                 c, ProgressBarCallback
+             ):
+                 raise ValueError(
+                     "HyperParameter and ProgressBar callbacks are defined internally "
+                     "and should not be passed as callbacks."
+                 )
+
+         # checkpoint callback saves checkpoints during training
+         self.callbacks.extend(
+             [
+                 HyperParametersCallback(self.cfg),
+                 ModelCheckpoint(
+                     dirpath=self.work_dir / Path("checkpoints"),
+                     filename=self.cfg.experiment_name,
+                     **self.cfg.training_config.checkpoint_callback.model_dump(),
+                 ),
+                 ProgressBarCallback(),
+             ]
+         )
+
+         # early stopping callback
+         if self.cfg.training_config.early_stopping_callback is not None:
+             self.callbacks.append(
+                 EarlyStopping(self.cfg.training_config.early_stopping_callback)
+             )
+
+     # TODO: is there a more elegant way than calling train again after _train_on_path?
+     def train(
+         self,
+         *,
+         datamodule: Optional[TrainDataModule] = None,
+         train_source: Optional[Union[Path, str, NDArray]] = None,
+         val_source: Optional[Union[Path, str, NDArray]] = None,
+         train_target: Optional[Union[Path, str, NDArray]] = None,
+         val_target: Optional[Union[Path, str, NDArray]] = None,
+         use_in_memory: bool = True,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 1,
+     ) -> None:
+         """
+         Train the model on the provided data.
+
+         If a datamodule is provided, then training will be performed using it.
+         Alternatively, the training data can be provided as arrays or paths.
+
+         If `use_in_memory` is set to True, the source provided as Path or str will be
+         loaded in memory if it fits. Otherwise, training will be performed by loading
+         patches from the files one by one. Training on arrays is always performed
+         in memory.
+
+         If no validation source is provided, then the validation is extracted from
+         the training data using `val_percentage` and `val_minimum_split`. In the case
+         of data provided as Path or str, the percentage and minimum number are applied
+         to the number of files. For arrays, it is the number of patches.
+
+         Parameters
+         ----------
+         datamodule : TrainDataModule, optional
+             Datamodule to train on, by default None.
+         train_source : pathlib.Path or str or NDArray, optional
+             Train source, if no datamodule is provided, by default None.
+         val_source : pathlib.Path or str or NDArray, optional
+             Validation source, if no datamodule is provided, by default None.
+         train_target : pathlib.Path or str or NDArray, optional
+             Train target source, if no datamodule is provided, by default None.
+         val_target : pathlib.Path or str or NDArray, optional
+             Validation target source, if no datamodule is provided, by default None.
+         use_in_memory : bool, optional
+             Use in memory dataset if possible, by default True.
+         val_percentage : float, optional
+             Percentage of validation extracted from training data, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of validation (patch or file) extracted from training data,
+             by default 1.
+
+         Raises
+         ------
+         ValueError
+             If both `datamodule` and `train_source` are provided.
+         ValueError
+             If sources are not of the same type (e.g. train is an array and val is
+             a Path).
+         ValueError
+             If the training target is provided to N2V.
+         ValueError
+             If neither a datamodule nor a source is provided.
+         """
+         if datamodule is not None and train_source is not None:
+             raise ValueError(
+                 "Only one of `datamodule` and `train_source` can be provided."
+             )
+
+         # check that inputs are the same type
+         source_types = {
+             type(s)
+             for s in (train_source, val_source, train_target, val_target)
+             if s is not None
+         }
+         if len(source_types) > 1:
+             raise ValueError("All sources should be of the same type.")
+
+         # train
+         if datamodule is not None:
+             self._train_on_datamodule(datamodule=datamodule)
+
+         else:
+             # raise error if target is provided to N2V
+             if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:
+                 if train_target is not None:
+                     raise ValueError(
+                         "Training target not compatible with N2V training."
+                     )
+
+             # dispatch the training
+             if isinstance(train_source, np.ndarray):
+                 # mypy checks
+                 assert isinstance(val_source, np.ndarray) or val_source is None
+                 assert isinstance(train_target, np.ndarray) or train_target is None
+                 assert isinstance(val_target, np.ndarray) or val_target is None
+
+                 self._train_on_array(
+                     train_source,
+                     val_source,
+                     train_target,
+                     val_target,
+                     val_percentage,
+                     val_minimum_split,
+                 )
+
+             elif isinstance(train_source, Path) or isinstance(train_source, str):
+                 # mypy checks
+                 assert (
+                     isinstance(val_source, Path)
+                     or isinstance(val_source, str)
+                     or val_source is None
+                 )
+                 assert (
+                     isinstance(train_target, Path)
+                     or isinstance(train_target, str)
+                     or train_target is None
+                 )
+                 assert (
+                     isinstance(val_target, Path)
+                     or isinstance(val_target, str)
+                     or val_target is None
+                 )
+
+                 self._train_on_path(
+                     train_source,
+                     val_source,
+                     train_target,
+                     val_target,
+                     use_in_memory,
+                     val_percentage,
+                     val_minimum_split,
+                 )
+
+             else:
+                 raise ValueError(
+                     f"Invalid input, expected a str, Path, array or TrainDataModule "
+                     f"instance (got {type(train_source)})."
+                 )
+
+     def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
+         """
+         Train the model on the provided datamodule.
+
+         Parameters
+         ----------
+         datamodule : TrainDataModule
+             Datamodule to train on.
+         """
+         # record datamodule
+         self.train_datamodule = datamodule
+
+         self.trainer.fit(self.model, datamodule=datamodule)
+
+     def _train_on_array(
+         self,
+         train_data: NDArray,
+         val_data: Optional[NDArray] = None,
+         train_target: Optional[NDArray] = None,
+         val_target: Optional[NDArray] = None,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 5,
+     ) -> None:
+         """
+         Train the model on the provided data arrays.
+
+         Parameters
+         ----------
+         train_data : NDArray
+             Training data.
+         val_data : NDArray, optional
+             Validation data, by default None.
+         train_target : NDArray, optional
+             Train target data, by default None.
+         val_target : NDArray, optional
+             Validation target data, by default None.
+         val_percentage : float, optional
+             Percentage of patches to use for validation, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of patches to use for validation, by default 5.
+         """
+         # create datamodule
+         datamodule = TrainDataModule(
+             data_config=self.cfg.data_config,
+             train_data=train_data,
+             val_data=val_data,
+             train_data_target=train_target,
+             val_data_target=val_target,
+             val_percentage=val_percentage,
+             val_minimum_split=val_minimum_split,
+         )
+
+         # train
+         self.train(datamodule=datamodule)
+
+     def _train_on_path(
+         self,
+         path_to_train_data: Union[Path, str],
+         path_to_val_data: Optional[Union[Path, str]] = None,
+         path_to_train_target: Optional[Union[Path, str]] = None,
+         path_to_val_target: Optional[Union[Path, str]] = None,
+         use_in_memory: bool = True,
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 1,
+     ) -> None:
+         """
+         Train the model on the provided data paths.
+
+         Parameters
+         ----------
+         path_to_train_data : pathlib.Path or str
+             Path to the training data.
+         path_to_val_data : pathlib.Path or str, optional
+             Path to validation data, by default None.
+         path_to_train_target : pathlib.Path or str, optional
+             Path to train target data, by default None.
+         path_to_val_target : pathlib.Path or str, optional
+             Path to validation target data, by default None.
+         use_in_memory : bool, optional
+             Use in memory dataset if possible, by default True.
+         val_percentage : float, optional
+             Percentage of files to use for validation, by default 0.1.
+         val_minimum_split : int, optional
+             Minimum number of files to use for validation, by default 1.
+         """
+         # sanity check on data (path exists)
+         path_to_train_data = check_path_exists(path_to_train_data)
+
+         if path_to_val_data is not None:
+             path_to_val_data = check_path_exists(path_to_val_data)
+
+         if path_to_train_target is not None:
+             path_to_train_target = check_path_exists(path_to_train_target)
+
+         if path_to_val_target is not None:
+             path_to_val_target = check_path_exists(path_to_val_target)
+
+         # create datamodule
+         datamodule = TrainDataModule(
+             data_config=self.cfg.data_config,
+             train_data=path_to_train_data,
+             val_data=path_to_val_data,
+             train_data_target=path_to_train_target,
+             val_data_target=path_to_val_target,
+             use_in_memory=use_in_memory,
+             val_percentage=val_percentage,
+             val_minimum_split=val_minimum_split,
+         )
+
+         # train
+         self.train(datamodule=datamodule)
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self, source: PredictDataModule
+     ) -> Union[list[NDArray], NDArray]: ...
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self,
+         source: Union[Path, str],
+         *,
+         batch_size: int = 1,
+         tile_size: Optional[tuple[int, ...]] = None,
+         tile_overlap: tuple[int, ...] = (48, 48),
+         axes: Optional[str] = None,
+         data_type: Optional[Literal["tiff", "custom"]] = None,
+         tta_transforms: bool = True,
+         dataloader_params: Optional[dict] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+     ) -> Union[list[NDArray], NDArray]: ...
+
+     @overload
+     def predict(  # numpydoc ignore=GL08
+         self,
+         source: NDArray,
+         *,
+         batch_size: int = 1,
+         tile_size: Optional[tuple[int, ...]] = None,
+         tile_overlap: tuple[int, ...] = (48, 48),
+         axes: Optional[str] = None,
+         data_type: Optional[Literal["array"]] = None,
+         tta_transforms: bool = True,
+         dataloader_params: Optional[dict] = None,
+     ) -> Union[list[NDArray], NDArray]: ...
+
+     def predict(
+         self,
+         source: Union[PredictDataModule, Path, str, NDArray],
+         *,
+         batch_size: Optional[int] = None,
+         tile_size: Optional[tuple[int, ...]] = None,
+         tile_overlap: Optional[tuple[int, ...]] = (48, 48),
+         axes: Optional[str] = None,
+         data_type: Optional[Literal["array", "tiff", "custom"]] = None,
+         tta_transforms: bool = True,
+         dataloader_params: Optional[dict] = None,
+         read_source_func: Optional[Callable] = None,
+         extension_filter: str = "",
+         **kwargs: Any,
+     ) -> Union[list[NDArray], NDArray]:
+         """
+         Make predictions on the provided data.
+
+         Input can be a PredictDataModule instance, a path to a data file, or a numpy
+         array.
+
+         If `data_type`, `axes` and `tile_size` are not provided, the training
+         configuration parameters will be used, with the `patch_size` instead of
+         `tile_size`.
+
+         Test-time augmentation (TTA) can be switched off using the `tta_transforms`
+         parameter. TTA applies all possible flips and 90 degree rotations to the
+         prediction input and averages the predictions. TTA should not be used if you
+         did not train with these augmentations.
+
+         Note that if you are using a UNet model and tiling, the tile size must be
+         divisible in every dimension by 2**d, where d is the depth of the model. This
+         avoids artefacts arising from the broken shift invariance induced by the
+         pooling layers of the UNet. If your image is smaller along one dimension (as
+         may happen along Z), consider padding your image.
+
+         Parameters
+         ----------
+         source : PredictDataModule, pathlib.Path, str or numpy.ndarray
+             Data to predict on.
+         batch_size : int, default=1
+             Batch size for prediction.
+         tile_size : tuple of int, optional
+             Size of the tiles to use for prediction.
+         tile_overlap : tuple of int, default=(48, 48)
+             Overlap between tiles, can be None.
+         axes : str, optional
+             Axes of the input data, by default None.
+         data_type : {"array", "tiff", "custom"}, optional
+             Type of the input data.
+         tta_transforms : bool, default=True
+             Whether to apply test-time augmentation.
+         dataloader_params : dict, optional
+             Parameters to pass to the dataloader.
+         read_source_func : Callable, optional
+             Function to read the source data.
+         extension_filter : str, default=""
+             Filter for the file extension.
+         **kwargs : Any
+             Unused.
+
+         Returns
+         -------
+         list of NDArray or NDArray
+             Predictions made by the model.
+
+         Raises
+         ------
+         ValueError
+             If mean and std are not provided in the configuration.
+         ValueError
+             If tile size is not divisible by 2**depth for UNet models.
+         ValueError
+             If tile overlap is not specified.
+         """
+         if (
+             self.cfg.data_config.image_means is None
+             or self.cfg.data_config.image_stds is None
+         ):
+             raise ValueError("Mean and std must be provided in the configuration.")
+
+         # tile size for UNets
+         if tile_size is not None:
+             model = self.cfg.algorithm_config.model
+
+             if model.architecture == SupportedArchitecture.UNET.value:
+                 # tile size must be equal to k*2^n, where n is the number of pooling
+                 # layers (equal to the depth) and k is an integer
+                 depth = model.depth
+                 tile_increment = 2**depth
+
+                 for i, t in enumerate(tile_size):
+                     if t % tile_increment != 0:
+                         raise ValueError(
+                             f"Tile size must be divisible by {tile_increment} along "
+                             f"all axes (got {t} for axis {i}). If your image size is "
+                             f"smaller along one axis (e.g. Z), consider padding the "
+                             f"image."
+                         )
+
+             # tile overlaps must be specified
+             if tile_overlap is None:
+                 raise ValueError("Tile overlap must be specified.")
+
+         # create the prediction datamodule
+         self.pred_datamodule = create_predict_datamodule(
+             pred_data=source,
+             data_type=data_type or self.cfg.data_config.data_type,
+             axes=axes or self.cfg.data_config.axes,
+             image_means=self.cfg.data_config.image_means,
+             image_stds=self.cfg.data_config.image_stds,
+             tile_size=tile_size,
+             tile_overlap=tile_overlap,
+             batch_size=batch_size or self.cfg.data_config.batch_size,
+             tta_transforms=tta_transforms,
+             read_source_func=read_source_func,
+             extension_filter=extension_filter,
+             dataloader_params=dataloader_params,
+         )
+
+         # predict
+         predictions = self.trainer.predict(
+             model=self.model, datamodule=self.pred_datamodule
+         )
+         return convert_outputs(predictions, self.pred_datamodule.tiled)
+
+     def export_to_bmz(
+         self,
+         path_to_archive: Union[Path, str],
+         friendly_model_name: str,
+         input_array: NDArray,
+         authors: list[dict],
+         general_description: str = "",
+         channel_names: Optional[list[str]] = None,
+         data_description: Optional[str] = None,
+     ) -> None:
+         """Export the model to the BioImage Model Zoo format.
+
+         This method packages the current weights into a zip file that can be uploaded
+         to the BioImage Model Zoo. The archive consists of the model weights, the model
+         specifications and various files (inputs, outputs, README, env.yaml etc.).
+
+         `path_to_archive` should point to a file with a ".zip" extension.
+
+         `friendly_model_name` is the name used for the model in the BMZ specs
+         and website; it should consist of letters, numbers, dashes, underscores and
+         parentheses only.
+
+         Input array must be of the same dimensions as the axes recorded in the
+         configuration of the `CAREamist`.
+
+         Parameters
+         ----------
+         path_to_archive : pathlib.Path or str
+             Path in which to save the model, including file name, which should end with
+             ".zip".
+         friendly_model_name : str
+             Name of the model as used in the BMZ specs; it should consist of letters,
+             numbers, dashes, underscores and parentheses only.
+         input_array : NDArray
+             Input array used to validate the model and as an example.
+         authors : list of dict
+             List of authors of the model.
+         general_description : str
+             General description of the model, used in the metadata of the BMZ archive.
+         channel_names : list of str, optional
+             Channel names, by default None.
+         data_description : str, optional
+             Description of the data, by default None.
+         """
+         # TODO: add in docs that it is expected that input_array dimensions match
+         # those in data_config
+
+         output_patch = self.predict(
+             input_array,
+             data_type=SupportedData.ARRAY.value,
+             tta_transforms=False,
+         )
+         output = np.concatenate(output_patch, axis=0)
+         input_array = reshape_array(input_array, self.cfg.data_config.axes)
+
+         export_to_bmz(
+             model=self.model,
+             config=self.cfg,
+             path_to_archive=path_to_archive,
+             model_name=friendly_model_name,
+             general_description=general_description,
+             authors=authors,
+             input_array=input_array,
+             output_array=output,
+             channel_names=channel_names,
+             data_description=data_description,
+         )
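
For orientation, below is a minimal usage sketch of the CAREamist class added in this release, pieced together from the signatures and docstrings in the diff above. It is illustrative only: the configuration file name, array shape, author entry and output file names are placeholders, and whether training actually runs depends on the contents of the configuration.

import numpy as np

from careamics.careamist import CAREamist

# instantiate from a yaml configuration file (a Configuration object or a
# checkpoint path would also work); checkpoints and logs go to `work_dir`
careamist = CAREamist(source="n2v_config.yml", work_dir="runs")

# train on an in-memory array; with no `val_source`, a validation split is
# taken from the training patches via `val_percentage` / `val_minimum_split`
train_array = np.random.rand(512, 512).astype(np.float32)
careamist.train(train_source=train_array, val_percentage=0.1)

# tiled prediction; for UNet models each tile dimension must be divisible
# by 2**depth (see the predict() docstring above)
prediction = careamist.predict(
    source=train_array,
    tile_size=(256, 256),
    tile_overlap=(48, 48),
    tta_transforms=True,
)

# package the trained weights for the BioImage Model Zoo
careamist.export_to_bmz(
    path_to_archive="model.zip",
    friendly_model_name="my_model",
    input_array=train_array,
    authors=[{"name": "Jane Doe"}],
    general_description="Example export",
)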