careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
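The file list shows the prediction helpers moving from `careamics/prediction` (and the removed `careamics/lightning_prediction_loop.py`) into the new `careamics/prediction_utils` package. A minimal sketch of what that move implies for importing code: `convert_outputs` and `create_pred_datamodule` are taken from the `careamist.py` diff below, while the `stitch_prediction` re-export is an assumption based only on the file paths above.

```python
# Sketch of the import change implied by the file moves above.
# `convert_outputs` and `create_pred_datamodule` appear in the careamist.py diff;
# the exact location of `stitch_prediction` after the move is an assumption.

# 0.1.0rc6
# from careamics.prediction import stitch_prediction

# 0.1.0rc7
from careamics.prediction_utils import convert_outputs, create_pred_datamodule
from careamics.prediction_utils.stitch_prediction import stitch_prediction
```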
careamics/careamist.py
CHANGED
@@ -1,9 +1,10 @@
 """A class to train, predict and export models in CAREamics."""
 
 from pathlib import Path
-from typing import Any, Callable,
+from typing import Any, Callable, Literal, Optional, Union, overload
 
 import numpy as np
+from numpy.typing import NDArray
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import (
     Callback,
@@ -15,56 +16,54 @@ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 from careamics.callbacks import ProgressBarCallback
 from careamics.config import (
     Configuration,
-    create_inference_configuration,
     load_configuration,
 )
 from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
 from careamics.dataset.dataset_utils import reshape_array
 from careamics.lightning_datamodule import CAREamicsTrainData
 from careamics.lightning_module import CAREamicsModule
-from careamics.lightning_prediction_datamodule import CAREamicsPredictData
-from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
 from careamics.model_io import export_to_bmz, load_pretrained
-from careamics.
+from careamics.prediction_utils import convert_outputs, create_pred_datamodule
 from careamics.utils import check_path_exists, get_logger
 
 from .callbacks import HyperParametersCallback
+from .lightning_prediction_datamodule import CAREamicsPredictData
 
 logger = get_logger(__name__)
 
 LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
 
 
-# TODO napari callbacks
-# TODO: how to do AMP? How to continue training?
 class CAREamist:
     """Main CAREamics class, allowing training and prediction using various algorithms.
 
     Parameters
     ----------
-    source :
+    source : pathlib.Path or str or CAREamics Configuration
         Path to a configuration file or a trained model.
-    work_dir :
+    work_dir : str, optional
         Path to working directory in which to save checkpoints and logs,
         by default None.
-    experiment_name : str,
-        Experiment name used for checkpoints
+    experiment_name : str, by default "CAREamics"
+        Experiment name used for checkpoints.
+    callbacks : list of Callback, optional
+        List of callbacks to use during training and prediction, by default None.
 
     Attributes
     ----------
-    model :
+    model : CAREamicsModule
        CAREamics model.
     cfg : Configuration
        CAREamics configuration.
     trainer : Trainer
        PyTorch Lightning trainer.
-    experiment_logger :
+    experiment_logger : TensorBoardLogger or WandbLogger
        Experiment logger, "wandb" or "tensorboard".
-    work_dir : Path
+    work_dir : pathlib.Path
        Working directory.
-    train_datamodule :
+    train_datamodule : CAREamicsTrainData
        Training datamodule.
-    pred_datamodule :
+    pred_datamodule : CAREamicsPredictData
        Prediction datamodule.
     """
 
@@ -74,6 +73,7 @@ class CAREamist:
         source: Union[Path, str],
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
     ) -> None: ...
 
     @overload
@@ -82,6 +82,7 @@ class CAREamist:
         source: Configuration,
         work_dir: Optional[str] = None,
         experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
     ) -> None: ...
 
     def __init__(
@@ -89,6 +90,7 @@ class CAREamist:
         source: Union[Path, str, Configuration],
         work_dir: Optional[Union[Path, str]] = None,
         experiment_name: str = "CAREamics",
+        callbacks: Optional[list[Callback]] = None,
     ) -> None:
         """
         Initialize CAREamist with a configuration object or a path.
@@ -105,13 +107,15 @@ class CAREamist:
 
         Parameters
         ----------
-        source :
+        source : pathlib.Path or str or CAREamics Configuration
            Path to a configuration file or a trained model.
-        work_dir :
+        work_dir : str, optional
            Path to working directory in which to save checkpoints and logs,
            by default None.
         experiment_name : str, optional
            Experiment name used for checkpoints, by default "CAREamics".
+        callbacks : list of Callback, optional
+            List of callbacks to use during training and prediction, by default None.
 
         Raises
         ------
@@ -164,7 +168,7 @@ class CAREamist:
             self.model, self.cfg = load_pretrained(source)
 
         # define the checkpoint saving callback
-        self.
+        self._define_callbacks(callbacks)
 
         # instantiate logger
         if self.cfg.training_config.has_logger():
@@ -188,32 +192,50 @@ class CAREamist:
             logger=self.experiment_logger,
         )
 
-        # change the prediction loop, necessary for tiled prediction
-        self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
-
         # place holder for the datamodules
         self.train_datamodule: Optional[CAREamicsTrainData] = None
         self.pred_datamodule: Optional[CAREamicsPredictData] = None
 
-    def _define_callbacks(self
+    def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
        """
        Define the callbacks for the training loop.
 
-
-
-
-        List of callbacks to
+        Parameters
+        ----------
+        callbacks : list of Callback, optional
+            List of callbacks to use during training and prediction, by default None.
        """
+        self.callbacks = [] if callbacks is None else callbacks
+
+        # check that user callbacks are not any of the CAREamics callbacks
+        for c in self.callbacks:
+            if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
+                raise ValueError(
+                    "ModelCheckpoint and EarlyStopping callbacks are already defined "
+                    "in CAREamics and should only be modified through the "
+                    "training configuration (see TrainingConfig)."
+                )
+
+            if isinstance(c, HyperParametersCallback) or isinstance(
+                c, ProgressBarCallback
+            ):
+                raise ValueError(
+                    "HyperParameter and ProgressBar callbacks are defined internally "
+                    "and should not be passed as callbacks."
+                )
+
        # checkpoint callback saves checkpoints during training
-        self.callbacks
-
-
-
-
-
-
-
+        self.callbacks.extend(
+            [
+                HyperParametersCallback(self.cfg),
+                ModelCheckpoint(
+                    dirpath=self.work_dir / Path("checkpoints"),
+                    filename=self.cfg.experiment_name,
+                    **self.cfg.training_config.checkpoint_callback.model_dump(),
+                ),
+                ProgressBarCallback(),
+            ]
+        )
 
        # early stopping callback
        if self.cfg.training_config.early_stopping_callback is not None:
@@ -221,16 +243,14 @@ class CAREamist:
                EarlyStopping(self.cfg.training_config.early_stopping_callback)
            )
 
-        return self.callbacks
-
    def train(
        self,
        *,
        datamodule: Optional[CAREamicsTrainData] = None,
-        train_source: Optional[Union[Path, str,
-        val_source: Optional[Union[Path, str,
-        train_target: Optional[Union[Path, str,
-        val_target: Optional[Union[Path, str,
+        train_source: Optional[Union[Path, str, NDArray]] = None,
+        val_source: Optional[Union[Path, str, NDArray]] = None,
+        train_target: Optional[Union[Path, str, NDArray]] = None,
+        val_target: Optional[Union[Path, str, NDArray]] = None,
        use_in_memory: bool = True,
        val_percentage: float = 0.1,
        val_minimum_split: int = 1,
@@ -253,15 +273,15 @@ class CAREamist:
 
        Parameters
        ----------
-        datamodule :
+        datamodule : CAREamicsTrainData, optional
            Datamodule to train on, by default None.
-        train_source :
+        train_source : pathlib.Path or str or NDArray, optional
            Train source, if no datamodule is provided, by default None.
-        val_source :
+        val_source : pathlib.Path or str or NDArray, optional
            Validation source, if no datamodule is provided, by default None.
-        train_target :
+        train_target : pathlib.Path or str or NDArray, optional
            Train target source, if no datamodule is provided, by default None.
-        val_target :
+        val_target : pathlib.Path or str or NDArray, optional
            Validation target source, if no datamodule is provided, by default None.
        use_in_memory : bool, optional
            Use in memory dataset if possible, by default True.
@@ -355,7 +375,7 @@ class CAREamist:
 
        else:
            raise ValueError(
-                f"Invalid input, expected a str, Path, array or
+                f"Invalid input, expected a str, Path, array or CAREamicsTrainData "
                f"instance (got {type(train_source)})."
            )
 
@@ -365,7 +385,7 @@ class CAREamist:
 
        Parameters
        ----------
-        datamodule :
+        datamodule : CAREamicsTrainData
            Datamodule to train on.
        """
        # record datamodule
@@ -375,10 +395,10 @@ class CAREamist:
 
    def _train_on_array(
        self,
-        train_data:
-        val_data: Optional[
-        train_target: Optional[
-        val_target: Optional[
+        train_data: NDArray,
+        val_data: Optional[NDArray] = None,
+        train_target: Optional[NDArray] = None,
+        val_target: Optional[NDArray] = None,
        val_percentage: float = 0.1,
        val_minimum_split: int = 5,
    ) -> None:
@@ -387,13 +407,13 @@ class CAREamist:
 
        Parameters
        ----------
-        train_data :
+        train_data : NDArray
            Training data.
-        val_data :
+        val_data : NDArray, optional
            Validation data, by default None.
-        train_target :
+        train_target : NDArray, optional
            Train target data, by default None.
-        val_target :
+        val_target : NDArray, optional
            Validation target data, by default None.
        val_percentage : float, optional
            Percentage of patches to use for validation, by default 0.1.
@@ -429,13 +449,13 @@ class CAREamist:
 
        Parameters
        ----------
-        path_to_train_data :
+        path_to_train_data : pathlib.Path or str
            Path to the training data.
-        path_to_val_data :
+        path_to_val_data : pathlib.Path or str, optional
            Path to validation data, by default None.
-        path_to_train_target :
+        path_to_train_target : pathlib.Path or str, optional
            Path to train target data, by default None.
-        path_to_val_target :
+        path_to_val_target : pathlib.Path or str, optional
            Path to validation target data, by default None.
        use_in_memory : bool, optional
            Use in memory dataset if possible, by default True.
@@ -477,7 +497,7 @@ class CAREamist:
        source: CAREamicsPredictData,
        *,
        checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list,
+    ) -> Union[list[NDArray], NDArray]: ...
 
    @overload
    def predict( # numpydoc ignore=GL08
@@ -485,59 +505,62 @@ class CAREamist:
        source: Union[Path, str],
        *,
        batch_size: int = 1,
-        tile_size: Optional[
-        tile_overlap:
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
        axes: Optional[str] = None,
        data_type: Optional[Literal["tiff", "custom"]] = None,
        tta_transforms: bool = True,
-        dataloader_params: Optional[
+        dataloader_params: Optional[dict] = None,
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
        checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list,
+    ) -> Union[list[NDArray], NDArray]: ...
 
    @overload
    def predict( # numpydoc ignore=GL08
        self,
-        source:
+        source: NDArray,
        *,
        batch_size: int = 1,
-        tile_size: Optional[
-        tile_overlap:
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
        axes: Optional[str] = None,
        data_type: Optional[Literal["array"]] = None,
        tta_transforms: bool = True,
-        dataloader_params: Optional[
+        dataloader_params: Optional[dict] = None,
        checkpoint: Optional[Literal["best", "last"]] = None,
-    ) -> Union[list,
+    ) -> Union[list[NDArray], NDArray]: ...
 
    def predict(
        self,
-        source: Union[CAREamicsPredictData, Path, str,
+        source: Union[CAREamicsPredictData, Path, str, NDArray],
        *,
-        batch_size: int =
-        tile_size: Optional[
-        tile_overlap:
+        batch_size: Optional[int] = None,
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: tuple[int, ...] = (48, 48),
        axes: Optional[str] = None,
        data_type: Optional[Literal["array", "tiff", "custom"]] = None,
        tta_transforms: bool = True,
-        dataloader_params: Optional[
+        dataloader_params: Optional[dict] = None,
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
        checkpoint: Optional[Literal["best", "last"]] = None,
        **kwargs: Any,
-    ) -> Union[
+    ) -> Union[list[NDArray], NDArray]:
        """
        Make predictions on the provided data.
 
-        Input can be a
+        Input can be a CAREamicsPredData instance, a path to a data file, or a numpy
+        array.
 
        If `data_type`, `axes` and `tile_size` are not provided, the training
        configuration parameters will be used, with the `patch_size` instead of
        `tile_size`.
 
        Test-time augmentation (TTA) can be switched off using the `tta_transforms`
-        parameter.
+        parameter. The TTA augmentation applies all possible flip and 90 degrees
+        rotations to the prediction input and averages the predictions. TTA augmentation
+        should not be used if you did not train with these augmentations.
 
        Note that if you are using a UNet model and tiling, the tile size must be
        divisible in every dimension by 2**d, where d is the depth of the model. This
@@ -547,221 +570,96 @@ class CAREamist:
 
        Parameters
        ----------
-        source :
+        source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
            Data to predict on.
-        batch_size : int,
-            Batch size for prediction
-        tile_size :
-            Size of the tiles to use for prediction
-        tile_overlap :
-            Overlap between tiles
-        axes :
+        batch_size : int, default=1
+            Batch size for prediction.
+        tile_size : tuple of int, optional
+            Size of the tiles to use for prediction.
+        tile_overlap : tuple of int, default=(48, 48)
+            Overlap between tiles.
+        axes : str, optional
            Axes of the input data, by default None.
-        data_type :
-            Type of the input data
-        tta_transforms : bool,
-            Whether to apply test-time augmentation
-        dataloader_params :
-            Parameters to pass to the dataloader
-        read_source_func :
-            Function to read the source data
-        extension_filter : str,
-            Filter for the file extension
-        checkpoint :
-            Checkpoint to use for prediction
+        data_type : {"array", "tiff", "custom"}, optional
+            Type of the input data.
+        tta_transforms : bool, default=True
+            Whether to apply test-time augmentation.
+        dataloader_params : dict, optional
+            Parameters to pass to the dataloader.
+        read_source_func : Callable, optional
+            Function to read the source data.
+        extension_filter : str, default=""
+            Filter for the file extension.
+        checkpoint : {"best", "last"}, optional
+            Checkpoint to use for prediction.
        **kwargs : Any
            Unused.
 
        Returns
        -------
-
+        list of NDArray or NDArray
            Predictions made by the model.
-
-        Raises
-        ------
-        ValueError
-            If the input is not a CAREamicsClay instance, a path or a numpy array.
        """
-        if
-
-
-
-
-
+        # Reuse batch size if not provided explicitly
+        if batch_size is None:
+            batch_size = (
+                self.train_datamodule.batch_size
+                if self.train_datamodule
+                else self.cfg.data_config.batch_size
            )
-        else:
-            if self.cfg is None:
-                raise ValueError(
-                    "No configuration found. Train a model or load from a "
-                    "checkpoint before predicting."
-                )
-            # create predict config, reuse training config if parameters missing
-            prediction_config = create_inference_configuration(
-                configuration=self.cfg,
-                tile_size=tile_size,
-                tile_overlap=tile_overlap,
-                data_type=data_type,
-                axes=axes,
-                tta_transforms=tta_transforms,
-                batch_size=batch_size,
-            )
-
-            # remove batch from dataloader parameters (priority given to config)
-            if dataloader_params is None:
-                dataloader_params = {}
-            if "batch_size" in dataloader_params:
-                del dataloader_params["batch_size"]
-
-        if isinstance(source, Path) or isinstance(source, str):
-            # Check the source
-            source_path = check_path_exists(source)
-
-            # create datamodule
-            datamodule = CAREamicsPredictData(
-                pred_config=prediction_config,
-                pred_data=source_path,
-                read_source_func=read_source_func,
-                extension_filter=extension_filter,
-                dataloader_params=dataloader_params,
-            )
-
-            # record datamodule
-            self.pred_datamodule = datamodule
-
-            return self.trainer.predict(
-                model=self.model, datamodule=datamodule, ckpt_path=checkpoint
-            )
-
-        elif isinstance(source, np.ndarray):
-            # create datamodule
-            datamodule = CAREamicsPredictData(
-                pred_config=prediction_config,
-                pred_data=source,
-                dataloader_params=dataloader_params,
-            )
-
-            # record datamodule
-            self.pred_datamodule = datamodule
-
-            return self.trainer.predict(
-                model=self.model, datamodule=datamodule, ckpt_path=checkpoint
-            )
-
-        else:
-            raise ValueError(
-                f"Invalid input. Expected a CAREamicsWood instance, paths or "
-                f"np.ndarray (got {type(source)})."
-            )
-
-    def _create_data_for_bmz(
-        self,
-        input_array: Optional[np.ndarray] = None,
-    ) -> np.ndarray:
-        """Create data for BMZ export.
-
-        If no `input_array` is provided, this method checks if there is a prediction
-        datamodule, or a training data module, to extract a patch. If none exists,
-        then a random aray is created.
-
-        If there is a non-singleton batch dimension, this method returns only the first
-        element.
-
-        Parameters
-        ----------
-        input_array : Optional[np.ndarray], optional
-            Input array, by default None.
-
-        Returns
-        -------
-        np.ndarray
-            Input data for BMZ export.
-
-        Raises
-        ------
-        ValueError
-            If mean and std are not provided in the configuration.
-        """
-        if input_array is None:
-            if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
-                raise ValueError(
-                    "Mean and std cannot be None in the configuration in order to"
-                    "export to the BMZ format. Was the model trained?"
-                )
-
-            # generate images, priority is given to the prediction data module
-            if self.pred_datamodule is not None:
-                # unpack a batch, ignore masks or targets
-                input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))
-
-                # convert torch.Tensor to numpy
-                input_patch = input_patch.numpy()
-
-                # denormalize
-                denormalize = Denormalize(
-                    mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
-                )
-                input_patch, _ = denormalize(input_patch)
-
-            elif self.train_datamodule is not None:
-                input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
-                input_patch = input_patch.numpy()
-
-                # denormalize
-                denormalize = Denormalize(
-                    mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
-                )
-                input_patch, _ = denormalize(input_patch)
-            else:
-                # create a random input array
-                input_patch = np.random.normal(
-                    loc=self.cfg.data_config.mean,
-                    scale=self.cfg.data_config.std,
-                    size=self.cfg.data_config.patch_size,
-                ).astype(np.float32)[
-                    np.newaxis, np.newaxis, ...
-                ]  # add S & C dimensions
-        else:
-            # potentially correct shape
-            input_patch = reshape_array(input_array, self.cfg.data_config.axes)
 
-
-
-
+        self.pred_datamodule = create_pred_datamodule(
+            source=source,
+            config=self.cfg,
+            batch_size=batch_size,
+            tile_size=tile_size,
+            tile_overlap=tile_overlap,
+            axes=axes,
+            data_type=data_type,
+            tta_transforms=tta_transforms,
+            dataloader_params=dataloader_params,
+            read_source_func=read_source_func,
+            extension_filter=extension_filter,
+        )
 
-
+        predictions = self.trainer.predict(
+            model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint
+        )
+        return convert_outputs(predictions, self.pred_datamodule.tiled)
 
    def export_to_bmz(
        self,
        path: Union[Path, str],
        name: str,
-
-
+        input_array: NDArray,
+        authors: list[dict],
        general_description: str = "",
-        channel_names: Optional[
+        channel_names: Optional[list[str]] = None,
        data_description: Optional[str] = None,
    ) -> None:
        """Export the model to the BioImage Model Zoo format.
 
-        Input array must be of
+        Input array must be of the same dimensions as the axes recorded in the
+        configuration of the `CAREamist`.
 
        Parameters
        ----------
-        path :
+        path : pathlib.Path or str
            Path to save the model.
        name : str
            Name of the model.
-
+        input_array : NDArray
+            Input array used to validate the model and as example.
+        authors : list of dict
            List of authors of the model.
-        input_array : Optional[np.ndarray], optional
-            Input array for the model, must be of shape SC(Z)YX, by default None.
        general_description : str
            General description of the model, used in the metadata of the BMZ archive.
-        channel_names :
+        channel_names : list of str, optional
            Channel names, by default None.
-        data_description :
+        data_description : str, optional
            Description of the data, by default None.
        """
-        input_patch = self.
+        input_patch = reshape_array(input_array, self.cfg.data_config.axes)
 
        # axes need to be reformated for the export because reshaping was done in the
        # datamodule
@@ -778,11 +676,10 @@ class CAREamist:
            tta_transforms=False,
        )
 
-        if
-
-
-
-        )
+        if isinstance(output_patch, list):
+            output = np.concatenate(output_patch, axis=0)
+        else:
+            output = output_patch
 
        export_to_bmz(
            model=self.model,
@@ -792,7 +689,7 @@ class CAREamist:
            general_description=general_description,
            authors=authors,
            input_array=input_patch,
-            output_array=
+            output_array=output,
            channel_names=channel_names,
            data_description=data_description,
        )