careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
"""A class to train, predict and export models in CAREamics."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Callable,
|
|
4
|
+
from typing import Any, Callable, Literal, Optional, Union, overload
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
|
+
from numpy.typing import NDArray
|
|
7
8
|
from pytorch_lightning import Trainer
|
|
8
9
|
from pytorch_lightning.callbacks import (
|
|
9
10
|
Callback,
|
|
@@ -15,55 +16,54 @@ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
|
|
15
16
|
from careamics.callbacks import ProgressBarCallback
|
|
16
17
|
from careamics.config import (
|
|
17
18
|
Configuration,
|
|
18
|
-
create_inference_configuration,
|
|
19
19
|
load_configuration,
|
|
20
20
|
)
|
|
21
|
-
from careamics.config.inference_model import TRANSFORMS_UNION
|
|
22
21
|
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
|
|
22
|
+
from careamics.dataset.dataset_utils import reshape_array
|
|
23
23
|
from careamics.lightning_datamodule import CAREamicsTrainData
|
|
24
24
|
from careamics.lightning_module import CAREamicsModule
|
|
25
|
-
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
|
|
26
|
-
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
|
|
27
25
|
from careamics.model_io import export_to_bmz, load_pretrained
|
|
26
|
+
from careamics.prediction_utils import convert_outputs, create_pred_datamodule
|
|
28
27
|
from careamics.utils import check_path_exists, get_logger
|
|
29
28
|
|
|
30
29
|
from .callbacks import HyperParametersCallback
|
|
30
|
+
from .lightning_prediction_datamodule import CAREamicsPredictData
|
|
31
31
|
|
|
32
32
|
logger = get_logger(__name__)
|
|
33
33
|
|
|
34
34
|
LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
# TODO napari callbacks
|
|
38
|
-
# TODO: how to do AMP? How to continue training?
|
|
39
37
|
class CAREamist:
|
|
40
38
|
"""Main CAREamics class, allowing training and prediction using various algorithms.
|
|
41
39
|
|
|
42
40
|
Parameters
|
|
43
41
|
----------
|
|
44
|
-
source :
|
|
42
|
+
source : pathlib.Path or str or CAREamics Configuration
|
|
45
43
|
Path to a configuration file or a trained model.
|
|
46
|
-
work_dir :
|
|
44
|
+
work_dir : str, optional
|
|
47
45
|
Path to working directory in which to save checkpoints and logs,
|
|
48
46
|
by default None.
|
|
49
|
-
experiment_name : str,
|
|
50
|
-
Experiment name used for checkpoints
|
|
47
|
+
experiment_name : str, by default "CAREamics"
|
|
48
|
+
Experiment name used for checkpoints.
|
|
49
|
+
callbacks : list of Callback, optional
|
|
50
|
+
List of callbacks to use during training and prediction, by default None.
|
|
51
51
|
|
|
52
52
|
Attributes
|
|
53
53
|
----------
|
|
54
|
-
model :
|
|
54
|
+
model : CAREamicsModule
|
|
55
55
|
CAREamics model.
|
|
56
56
|
cfg : Configuration
|
|
57
57
|
CAREamics configuration.
|
|
58
58
|
trainer : Trainer
|
|
59
59
|
PyTorch Lightning trainer.
|
|
60
|
-
experiment_logger :
|
|
60
|
+
experiment_logger : TensorBoardLogger or WandbLogger
|
|
61
61
|
Experiment logger, "wandb" or "tensorboard".
|
|
62
|
-
work_dir : Path
|
|
62
|
+
work_dir : pathlib.Path
|
|
63
63
|
Working directory.
|
|
64
|
-
train_datamodule :
|
|
64
|
+
train_datamodule : CAREamicsTrainData
|
|
65
65
|
Training datamodule.
|
|
66
|
-
pred_datamodule :
|
|
66
|
+
pred_datamodule : CAREamicsPredictData
|
|
67
67
|
Prediction datamodule.
|
|
68
68
|
"""
|
|
69
69
|
|
|
@@ -73,6 +73,7 @@ class CAREamist:
|
|
|
73
73
|
source: Union[Path, str],
|
|
74
74
|
work_dir: Optional[str] = None,
|
|
75
75
|
experiment_name: str = "CAREamics",
|
|
76
|
+
callbacks: Optional[list[Callback]] = None,
|
|
76
77
|
) -> None: ...
|
|
77
78
|
|
|
78
79
|
@overload
|
|
@@ -81,6 +82,7 @@ class CAREamist:
|
|
|
81
82
|
source: Configuration,
|
|
82
83
|
work_dir: Optional[str] = None,
|
|
83
84
|
experiment_name: str = "CAREamics",
|
|
85
|
+
callbacks: Optional[list[Callback]] = None,
|
|
84
86
|
) -> None: ...
|
|
85
87
|
|
|
86
88
|
def __init__(
|
|
@@ -88,6 +90,7 @@ class CAREamist:
|
|
|
88
90
|
source: Union[Path, str, Configuration],
|
|
89
91
|
work_dir: Optional[Union[Path, str]] = None,
|
|
90
92
|
experiment_name: str = "CAREamics",
|
|
93
|
+
callbacks: Optional[list[Callback]] = None,
|
|
91
94
|
) -> None:
|
|
92
95
|
"""
|
|
93
96
|
Initialize CAREamist with a configuration object or a path.
|
|
@@ -104,13 +107,15 @@ class CAREamist:
|
|
|
104
107
|
|
|
105
108
|
Parameters
|
|
106
109
|
----------
|
|
107
|
-
source :
|
|
110
|
+
source : pathlib.Path or str or CAREamics Configuration
|
|
108
111
|
Path to a configuration file or a trained model.
|
|
109
|
-
work_dir :
|
|
112
|
+
work_dir : str, optional
|
|
110
113
|
Path to working directory in which to save checkpoints and logs,
|
|
111
114
|
by default None.
|
|
112
115
|
experiment_name : str, optional
|
|
113
116
|
Experiment name used for checkpoints, by default "CAREamics".
|
|
117
|
+
callbacks : list of Callback, optional
|
|
118
|
+
List of callbacks to use during training and prediction, by default None.
|
|
114
119
|
|
|
115
120
|
Raises
|
|
116
121
|
------
|
|
@@ -163,7 +168,7 @@ class CAREamist:
|
|
|
163
168
|
self.model, self.cfg = load_pretrained(source)
|
|
164
169
|
|
|
165
170
|
# define the checkpoint saving callback
|
|
166
|
-
self.
|
|
171
|
+
self._define_callbacks(callbacks)
|
|
167
172
|
|
|
168
173
|
# instantiate logger
|
|
169
174
|
if self.cfg.training_config.has_logger():
|
|
@@ -187,32 +192,50 @@ class CAREamist:
|
|
|
187
192
|
logger=self.experiment_logger,
|
|
188
193
|
)
|
|
189
194
|
|
|
190
|
-
# change the prediction loop, necessary for tiled prediction
|
|
191
|
-
self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
|
|
192
|
-
|
|
193
195
|
# place holder for the datamodules
|
|
194
196
|
self.train_datamodule: Optional[CAREamicsTrainData] = None
|
|
195
197
|
self.pred_datamodule: Optional[CAREamicsPredictData] = None
|
|
196
198
|
|
|
197
|
-
def _define_callbacks(self
|
|
199
|
+
def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
|
|
198
200
|
"""
|
|
199
201
|
Define the callbacks for the training loop.
|
|
200
202
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
List of callbacks to
|
|
203
|
+
Parameters
|
|
204
|
+
----------
|
|
205
|
+
callbacks : list of Callback, optional
|
|
206
|
+
List of callbacks to use during training and prediction, by default None.
|
|
205
207
|
"""
|
|
208
|
+
self.callbacks = [] if callbacks is None else callbacks
|
|
209
|
+
|
|
210
|
+
# check that user callbacks are not any of the CAREamics callbacks
|
|
211
|
+
for c in self.callbacks:
|
|
212
|
+
if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
|
|
213
|
+
raise ValueError(
|
|
214
|
+
"ModelCheckpoint and EarlyStopping callbacks are already defined "
|
|
215
|
+
"in CAREamics and should only be modified through the "
|
|
216
|
+
"training configuration (see TrainingConfig)."
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
if isinstance(c, HyperParametersCallback) or isinstance(
|
|
220
|
+
c, ProgressBarCallback
|
|
221
|
+
):
|
|
222
|
+
raise ValueError(
|
|
223
|
+
"HyperParameter and ProgressBar callbacks are defined internally "
|
|
224
|
+
"and should not be passed as callbacks."
|
|
225
|
+
)
|
|
226
|
+
|
|
206
227
|
# checkpoint callback saves checkpoints during training
|
|
207
|
-
self.callbacks
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
228
|
+
self.callbacks.extend(
|
|
229
|
+
[
|
|
230
|
+
HyperParametersCallback(self.cfg),
|
|
231
|
+
ModelCheckpoint(
|
|
232
|
+
dirpath=self.work_dir / Path("checkpoints"),
|
|
233
|
+
filename=self.cfg.experiment_name,
|
|
234
|
+
**self.cfg.training_config.checkpoint_callback.model_dump(),
|
|
235
|
+
),
|
|
236
|
+
ProgressBarCallback(),
|
|
237
|
+
]
|
|
238
|
+
)
|
|
216
239
|
|
|
217
240
|
# early stopping callback
|
|
218
241
|
if self.cfg.training_config.early_stopping_callback is not None:
|
|
@@ -220,16 +243,14 @@ class CAREamist:
|
|
|
220
243
|
EarlyStopping(self.cfg.training_config.early_stopping_callback)
|
|
221
244
|
)
|
|
222
245
|
|
|
223
|
-
return self.callbacks
|
|
224
|
-
|
|
225
246
|
def train(
|
|
226
247
|
self,
|
|
227
248
|
*,
|
|
228
249
|
datamodule: Optional[CAREamicsTrainData] = None,
|
|
229
|
-
train_source: Optional[Union[Path, str,
|
|
230
|
-
val_source: Optional[Union[Path, str,
|
|
231
|
-
train_target: Optional[Union[Path, str,
|
|
232
|
-
val_target: Optional[Union[Path, str,
|
|
250
|
+
train_source: Optional[Union[Path, str, NDArray]] = None,
|
|
251
|
+
val_source: Optional[Union[Path, str, NDArray]] = None,
|
|
252
|
+
train_target: Optional[Union[Path, str, NDArray]] = None,
|
|
253
|
+
val_target: Optional[Union[Path, str, NDArray]] = None,
|
|
233
254
|
use_in_memory: bool = True,
|
|
234
255
|
val_percentage: float = 0.1,
|
|
235
256
|
val_minimum_split: int = 1,
|
|
@@ -252,15 +273,15 @@ class CAREamist:
|
|
|
252
273
|
|
|
253
274
|
Parameters
|
|
254
275
|
----------
|
|
255
|
-
datamodule :
|
|
276
|
+
datamodule : CAREamicsTrainData, optional
|
|
256
277
|
Datamodule to train on, by default None.
|
|
257
|
-
train_source :
|
|
278
|
+
train_source : pathlib.Path or str or NDArray, optional
|
|
258
279
|
Train source, if no datamodule is provided, by default None.
|
|
259
|
-
val_source :
|
|
280
|
+
val_source : pathlib.Path or str or NDArray, optional
|
|
260
281
|
Validation source, if no datamodule is provided, by default None.
|
|
261
|
-
train_target :
|
|
282
|
+
train_target : pathlib.Path or str or NDArray, optional
|
|
262
283
|
Train target source, if no datamodule is provided, by default None.
|
|
263
|
-
val_target :
|
|
284
|
+
val_target : pathlib.Path or str or NDArray, optional
|
|
264
285
|
Validation target source, if no datamodule is provided, by default None.
|
|
265
286
|
use_in_memory : bool, optional
|
|
266
287
|
Use in memory dataset if possible, by default True.
|
|
@@ -354,7 +375,7 @@ class CAREamist:
|
|
|
354
375
|
|
|
355
376
|
else:
|
|
356
377
|
raise ValueError(
|
|
357
|
-
f"Invalid input, expected a str, Path, array or
|
|
378
|
+
f"Invalid input, expected a str, Path, array or CAREamicsTrainData "
|
|
358
379
|
f"instance (got {type(train_source)})."
|
|
359
380
|
)
|
|
360
381
|
|
|
@@ -364,7 +385,7 @@ class CAREamist:
|
|
|
364
385
|
|
|
365
386
|
Parameters
|
|
366
387
|
----------
|
|
367
|
-
datamodule :
|
|
388
|
+
datamodule : CAREamicsTrainData
|
|
368
389
|
Datamodule to train on.
|
|
369
390
|
"""
|
|
370
391
|
# record datamodule
|
|
@@ -374,10 +395,10 @@ class CAREamist:
|
|
|
374
395
|
|
|
375
396
|
def _train_on_array(
|
|
376
397
|
self,
|
|
377
|
-
train_data:
|
|
378
|
-
val_data: Optional[
|
|
379
|
-
train_target: Optional[
|
|
380
|
-
val_target: Optional[
|
|
398
|
+
train_data: NDArray,
|
|
399
|
+
val_data: Optional[NDArray] = None,
|
|
400
|
+
train_target: Optional[NDArray] = None,
|
|
401
|
+
val_target: Optional[NDArray] = None,
|
|
381
402
|
val_percentage: float = 0.1,
|
|
382
403
|
val_minimum_split: int = 5,
|
|
383
404
|
) -> None:
|
|
@@ -386,13 +407,13 @@ class CAREamist:
|
|
|
386
407
|
|
|
387
408
|
Parameters
|
|
388
409
|
----------
|
|
389
|
-
train_data :
|
|
410
|
+
train_data : NDArray
|
|
390
411
|
Training data.
|
|
391
|
-
val_data :
|
|
412
|
+
val_data : NDArray, optional
|
|
392
413
|
Validation data, by default None.
|
|
393
|
-
train_target :
|
|
414
|
+
train_target : NDArray, optional
|
|
394
415
|
Train target data, by default None.
|
|
395
|
-
val_target :
|
|
416
|
+
val_target : NDArray, optional
|
|
396
417
|
Validation target data, by default None.
|
|
397
418
|
val_percentage : float, optional
|
|
398
419
|
Percentage of patches to use for validation, by default 0.1.
|
|
@@ -428,13 +449,13 @@ class CAREamist:
|
|
|
428
449
|
|
|
429
450
|
Parameters
|
|
430
451
|
----------
|
|
431
|
-
path_to_train_data :
|
|
452
|
+
path_to_train_data : pathlib.Path or str
|
|
432
453
|
Path to the training data.
|
|
433
|
-
path_to_val_data :
|
|
454
|
+
path_to_val_data : pathlib.Path or str, optional
|
|
434
455
|
Path to validation data, by default None.
|
|
435
|
-
path_to_train_target :
|
|
456
|
+
path_to_train_target : pathlib.Path or str, optional
|
|
436
457
|
Path to train target data, by default None.
|
|
437
|
-
path_to_val_target :
|
|
458
|
+
path_to_val_target : pathlib.Path or str, optional
|
|
438
459
|
Path to validation target data, by default None.
|
|
439
460
|
use_in_memory : bool, optional
|
|
440
461
|
Use in memory dataset if possible, by default True.
|
|
@@ -476,7 +497,7 @@ class CAREamist:
|
|
|
476
497
|
source: CAREamicsPredictData,
|
|
477
498
|
*,
|
|
478
499
|
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
479
|
-
) -> Union[list,
|
|
500
|
+
) -> Union[list[NDArray], NDArray]: ...
|
|
480
501
|
|
|
481
502
|
@overload
|
|
482
503
|
def predict( # numpydoc ignore=GL08
|
|
@@ -484,64 +505,62 @@ class CAREamist:
|
|
|
484
505
|
source: Union[Path, str],
|
|
485
506
|
*,
|
|
486
507
|
batch_size: int = 1,
|
|
487
|
-
tile_size: Optional[
|
|
488
|
-
tile_overlap:
|
|
508
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
509
|
+
tile_overlap: tuple[int, ...] = (48, 48),
|
|
489
510
|
axes: Optional[str] = None,
|
|
490
511
|
data_type: Optional[Literal["tiff", "custom"]] = None,
|
|
491
|
-
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
492
512
|
tta_transforms: bool = True,
|
|
493
|
-
dataloader_params: Optional[
|
|
513
|
+
dataloader_params: Optional[dict] = None,
|
|
494
514
|
read_source_func: Optional[Callable] = None,
|
|
495
515
|
extension_filter: str = "",
|
|
496
516
|
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
497
|
-
) -> Union[list,
|
|
517
|
+
) -> Union[list[NDArray], NDArray]: ...
|
|
498
518
|
|
|
499
519
|
@overload
|
|
500
520
|
def predict( # numpydoc ignore=GL08
|
|
501
521
|
self,
|
|
502
|
-
source:
|
|
522
|
+
source: NDArray,
|
|
503
523
|
*,
|
|
504
524
|
batch_size: int = 1,
|
|
505
|
-
tile_size: Optional[
|
|
506
|
-
tile_overlap:
|
|
525
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
526
|
+
tile_overlap: tuple[int, ...] = (48, 48),
|
|
507
527
|
axes: Optional[str] = None,
|
|
508
528
|
data_type: Optional[Literal["array"]] = None,
|
|
509
|
-
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
510
529
|
tta_transforms: bool = True,
|
|
511
|
-
dataloader_params: Optional[
|
|
530
|
+
dataloader_params: Optional[dict] = None,
|
|
512
531
|
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
513
|
-
) -> Union[list,
|
|
532
|
+
) -> Union[list[NDArray], NDArray]: ...
|
|
514
533
|
|
|
515
534
|
def predict(
|
|
516
535
|
self,
|
|
517
|
-
source: Union[CAREamicsPredictData, Path, str,
|
|
536
|
+
source: Union[CAREamicsPredictData, Path, str, NDArray],
|
|
518
537
|
*,
|
|
519
|
-
batch_size: int =
|
|
520
|
-
tile_size: Optional[
|
|
521
|
-
tile_overlap:
|
|
538
|
+
batch_size: Optional[int] = None,
|
|
539
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
540
|
+
tile_overlap: tuple[int, ...] = (48, 48),
|
|
522
541
|
axes: Optional[str] = None,
|
|
523
542
|
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
|
|
524
|
-
transforms: Optional[List[TRANSFORMS_UNION]] = None,
|
|
525
543
|
tta_transforms: bool = True,
|
|
526
|
-
dataloader_params: Optional[
|
|
544
|
+
dataloader_params: Optional[dict] = None,
|
|
527
545
|
read_source_func: Optional[Callable] = None,
|
|
528
546
|
extension_filter: str = "",
|
|
529
547
|
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
530
548
|
**kwargs: Any,
|
|
531
|
-
) -> Union[
|
|
549
|
+
) -> Union[list[NDArray], NDArray]:
|
|
532
550
|
"""
|
|
533
551
|
Make predictions on the provided data.
|
|
534
552
|
|
|
535
|
-
Input can be a
|
|
553
|
+
Input can be a CAREamicsPredData instance, a path to a data file, or a numpy
|
|
554
|
+
array.
|
|
536
555
|
|
|
537
556
|
If `data_type`, `axes` and `tile_size` are not provided, the training
|
|
538
557
|
configuration parameters will be used, with the `patch_size` instead of
|
|
539
558
|
`tile_size`.
|
|
540
559
|
|
|
541
|
-
The default transforms are defined in the `InferenceModel` Pydantic model.
|
|
542
|
-
|
|
543
560
|
Test-time augmentation (TTA) can be switched off using the `tta_transforms`
|
|
544
|
-
parameter.
|
|
561
|
+
parameter. The TTA augmentation applies all possible flip and 90 degrees
|
|
562
|
+
rotations to the prediction input and averages the predictions. TTA augmentation
|
|
563
|
+
should not be used if you did not train with these augmentations.
|
|
545
564
|
|
|
546
565
|
Note that if you are using a UNet model and tiling, the tile size must be
|
|
547
566
|
divisible in every dimension by 2**d, where d is the depth of the model. This
|
|
@@ -551,181 +570,96 @@ class CAREamist:
|
|
|
551
570
|
|
|
552
571
|
Parameters
|
|
553
572
|
----------
|
|
554
|
-
source :
|
|
573
|
+
source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
|
|
555
574
|
Data to predict on.
|
|
556
|
-
batch_size : int,
|
|
557
|
-
Batch size for prediction
|
|
558
|
-
tile_size :
|
|
559
|
-
Size of the tiles to use for prediction
|
|
560
|
-
tile_overlap :
|
|
561
|
-
Overlap between tiles
|
|
562
|
-
axes :
|
|
575
|
+
batch_size : int, default=1
|
|
576
|
+
Batch size for prediction.
|
|
577
|
+
tile_size : tuple of int, optional
|
|
578
|
+
Size of the tiles to use for prediction.
|
|
579
|
+
tile_overlap : tuple of int, default=(48, 48)
|
|
580
|
+
Overlap between tiles.
|
|
581
|
+
axes : str, optional
|
|
563
582
|
Axes of the input data, by default None.
|
|
564
|
-
data_type :
|
|
565
|
-
Type of the input data
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
checkpoint : Optional[Literal["best", "last"]], optional
|
|
577
|
-
Checkpoint to use for prediction, by default None.
|
|
583
|
+
data_type : {"array", "tiff", "custom"}, optional
|
|
584
|
+
Type of the input data.
|
|
585
|
+
tta_transforms : bool, default=True
|
|
586
|
+
Whether to apply test-time augmentation.
|
|
587
|
+
dataloader_params : dict, optional
|
|
588
|
+
Parameters to pass to the dataloader.
|
|
589
|
+
read_source_func : Callable, optional
|
|
590
|
+
Function to read the source data.
|
|
591
|
+
extension_filter : str, default=""
|
|
592
|
+
Filter for the file extension.
|
|
593
|
+
checkpoint : {"best", "last"}, optional
|
|
594
|
+
Checkpoint to use for prediction.
|
|
578
595
|
**kwargs : Any
|
|
579
596
|
Unused.
|
|
580
597
|
|
|
581
598
|
Returns
|
|
582
599
|
-------
|
|
583
|
-
|
|
600
|
+
list of NDArray or NDArray
|
|
584
601
|
Predictions made by the model.
|
|
585
|
-
|
|
586
|
-
Raises
|
|
587
|
-
------
|
|
588
|
-
ValueError
|
|
589
|
-
If the input is not a CAREamicsClay instance, a path or a numpy array.
|
|
590
602
|
"""
|
|
591
|
-
if
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
)
|
|
598
|
-
else:
|
|
599
|
-
if self.cfg is None:
|
|
600
|
-
raise ValueError(
|
|
601
|
-
"No configuration found. Train a model or load from a "
|
|
602
|
-
"checkpoint before predicting."
|
|
603
|
-
)
|
|
604
|
-
# create predict config, reuse training config if parameters missing
|
|
605
|
-
prediction_config = create_inference_configuration(
|
|
606
|
-
configuration=self.cfg,
|
|
607
|
-
tile_size=tile_size,
|
|
608
|
-
tile_overlap=tile_overlap,
|
|
609
|
-
data_type=data_type,
|
|
610
|
-
axes=axes,
|
|
611
|
-
transforms=transforms,
|
|
612
|
-
tta_transforms=tta_transforms,
|
|
613
|
-
batch_size=batch_size,
|
|
603
|
+
# Reuse batch size if not provided explicitly
|
|
604
|
+
if batch_size is None:
|
|
605
|
+
batch_size = (
|
|
606
|
+
self.train_datamodule.batch_size
|
|
607
|
+
if self.train_datamodule
|
|
608
|
+
else self.cfg.data_config.batch_size
|
|
614
609
|
)
|
|
615
610
|
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
pred_data=source_path,
|
|
630
|
-
read_source_func=read_source_func,
|
|
631
|
-
extension_filter=extension_filter,
|
|
632
|
-
dataloader_params=dataloader_params,
|
|
633
|
-
)
|
|
634
|
-
|
|
635
|
-
# record datamodule
|
|
636
|
-
self.pred_datamodule = datamodule
|
|
637
|
-
|
|
638
|
-
return self.trainer.predict(
|
|
639
|
-
model=self.model, datamodule=datamodule, ckpt_path=checkpoint
|
|
640
|
-
)
|
|
641
|
-
|
|
642
|
-
elif isinstance(source, np.ndarray):
|
|
643
|
-
# create datamodule
|
|
644
|
-
datamodule = CAREamicsPredictData(
|
|
645
|
-
pred_config=prediction_config,
|
|
646
|
-
pred_data=source,
|
|
647
|
-
dataloader_params=dataloader_params,
|
|
648
|
-
)
|
|
649
|
-
|
|
650
|
-
# record datamodule
|
|
651
|
-
self.pred_datamodule = datamodule
|
|
652
|
-
|
|
653
|
-
return self.trainer.predict(
|
|
654
|
-
model=self.model, datamodule=datamodule, ckpt_path=checkpoint
|
|
655
|
-
)
|
|
611
|
+
self.pred_datamodule = create_pred_datamodule(
|
|
612
|
+
source=source,
|
|
613
|
+
config=self.cfg,
|
|
614
|
+
batch_size=batch_size,
|
|
615
|
+
tile_size=tile_size,
|
|
616
|
+
tile_overlap=tile_overlap,
|
|
617
|
+
axes=axes,
|
|
618
|
+
data_type=data_type,
|
|
619
|
+
tta_transforms=tta_transforms,
|
|
620
|
+
dataloader_params=dataloader_params,
|
|
621
|
+
read_source_func=read_source_func,
|
|
622
|
+
extension_filter=extension_filter,
|
|
623
|
+
)
|
|
656
624
|
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
)
|
|
625
|
+
predictions = self.trainer.predict(
|
|
626
|
+
model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint
|
|
627
|
+
)
|
|
628
|
+
return convert_outputs(predictions, self.pred_datamodule.tiled)
|
|
662
629
|
|
|
663
630
|
def export_to_bmz(
|
|
664
631
|
self,
|
|
665
632
|
path: Union[Path, str],
|
|
666
633
|
name: str,
|
|
667
|
-
|
|
668
|
-
|
|
634
|
+
input_array: NDArray,
|
|
635
|
+
authors: list[dict],
|
|
669
636
|
general_description: str = "",
|
|
670
|
-
channel_names: Optional[
|
|
637
|
+
channel_names: Optional[list[str]] = None,
|
|
671
638
|
data_description: Optional[str] = None,
|
|
672
639
|
) -> None:
|
|
673
640
|
"""Export the model to the BioImage Model Zoo format.
|
|
674
641
|
|
|
675
|
-
Input array must be of
|
|
642
|
+
Input array must be of the same dimensions as the axes recorded in the
|
|
643
|
+
configuration of the `CAREamist`.
|
|
676
644
|
|
|
677
645
|
Parameters
|
|
678
646
|
----------
|
|
679
|
-
path :
|
|
647
|
+
path : pathlib.Path or str
|
|
680
648
|
Path to save the model.
|
|
681
649
|
name : str
|
|
682
650
|
Name of the model.
|
|
683
|
-
|
|
651
|
+
input_array : NDArray
|
|
652
|
+
Input array used to validate the model and as example.
|
|
653
|
+
authors : list of dict
|
|
684
654
|
List of authors of the model.
|
|
685
|
-
input_array : Optional[np.ndarray], optional
|
|
686
|
-
Input array for the model, must be of shape SC(Z)YX, by default None.
|
|
687
655
|
general_description : str
|
|
688
656
|
General description of the model, used in the metadata of the BMZ archive.
|
|
689
|
-
channel_names :
|
|
657
|
+
channel_names : list of str, optional
|
|
690
658
|
Channel names, by default None.
|
|
691
|
-
data_description :
|
|
659
|
+
data_description : str, optional
|
|
692
660
|
Description of the data, by default None.
|
|
693
661
|
"""
|
|
694
|
-
|
|
695
|
-
# generate images, priority is given to the prediction data module
|
|
696
|
-
if self.pred_datamodule is not None:
|
|
697
|
-
# unpack a batch, ignore masks or targets
|
|
698
|
-
input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))
|
|
699
|
-
|
|
700
|
-
# convert torch.Tensor to numpy
|
|
701
|
-
input_patch = input_patch.numpy()
|
|
702
|
-
elif self.train_datamodule is not None:
|
|
703
|
-
input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
|
|
704
|
-
input_patch = input_patch.numpy()
|
|
705
|
-
else:
|
|
706
|
-
if (
|
|
707
|
-
self.cfg.data_config.mean is None
|
|
708
|
-
or self.cfg.data_config.std is None
|
|
709
|
-
):
|
|
710
|
-
raise ValueError(
|
|
711
|
-
"Mean and std cannot be None in the configuration in order to"
|
|
712
|
-
"export to the BMZ format. Was the model trained?"
|
|
713
|
-
)
|
|
714
|
-
|
|
715
|
-
# create a random input array
|
|
716
|
-
input_patch = np.random.normal(
|
|
717
|
-
loc=self.cfg.data_config.mean,
|
|
718
|
-
scale=self.cfg.data_config.std,
|
|
719
|
-
size=self.cfg.data_config.patch_size,
|
|
720
|
-
).astype(np.float32)[
|
|
721
|
-
np.newaxis, np.newaxis, ...
|
|
722
|
-
] # add S & C dimensions
|
|
723
|
-
else:
|
|
724
|
-
input_patch = input_array
|
|
725
|
-
|
|
726
|
-
# if there is a batch dimension
|
|
727
|
-
if input_patch.shape[0] > 1:
|
|
728
|
-
input_patch = input_patch[0:1, ...] # keep singleton dim
|
|
662
|
+
input_patch = reshape_array(input_array, self.cfg.data_config.axes)
|
|
729
663
|
|
|
730
664
|
# axes need to be reformated for the export because reshaping was done in the
|
|
731
665
|
# datamodule
|
|
@@ -742,11 +676,10 @@ class CAREamist:
|
|
|
742
676
|
tta_transforms=False,
|
|
743
677
|
)
|
|
744
678
|
|
|
745
|
-
if
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
)
|
|
679
|
+
if isinstance(output_patch, list):
|
|
680
|
+
output = np.concatenate(output_patch, axis=0)
|
|
681
|
+
else:
|
|
682
|
+
output = output_patch
|
|
750
683
|
|
|
751
684
|
export_to_bmz(
|
|
752
685
|
model=self.model,
|
|
@@ -756,7 +689,7 @@ class CAREamist:
|
|
|
756
689
|
general_description=general_description,
|
|
757
690
|
authors=authors,
|
|
758
691
|
input_array=input_patch,
|
|
759
|
-
output_array=
|
|
692
|
+
output_array=output,
|
|
760
693
|
channel_names=channel_names,
|
|
761
694
|
data_description=data_description,
|
|
762
695
|
)
|