careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
"""A class to train, predict and export models in CAREamics."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Callable,
|
|
4
|
+
from typing import Any, Callable, Literal, Optional, Union, overload
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
|
+
from numpy.typing import NDArray
|
|
7
8
|
from pytorch_lightning import Trainer
|
|
8
9
|
from pytorch_lightning.callbacks import (
|
|
9
10
|
Callback,
|
|
@@ -12,59 +13,64 @@ from pytorch_lightning.callbacks import (
|
|
|
12
13
|
)
|
|
13
14
|
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
|
14
15
|
|
|
15
|
-
from careamics.callbacks import ProgressBarCallback
|
|
16
16
|
from careamics.config import (
|
|
17
17
|
Configuration,
|
|
18
|
-
create_inference_configuration,
|
|
19
18
|
load_configuration,
|
|
20
19
|
)
|
|
21
|
-
from careamics.config.support import
|
|
20
|
+
from careamics.config.support import (
|
|
21
|
+
SupportedAlgorithm,
|
|
22
|
+
SupportedArchitecture,
|
|
23
|
+
SupportedData,
|
|
24
|
+
SupportedLogger,
|
|
25
|
+
)
|
|
22
26
|
from careamics.dataset.dataset_utils import reshape_array
|
|
23
|
-
from careamics.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
+
from careamics.lightning import (
|
|
28
|
+
CAREamicsModule,
|
|
29
|
+
HyperParametersCallback,
|
|
30
|
+
PredictDataModule,
|
|
31
|
+
ProgressBarCallback,
|
|
32
|
+
TrainDataModule,
|
|
33
|
+
create_predict_datamodule,
|
|
34
|
+
)
|
|
27
35
|
from careamics.model_io import export_to_bmz, load_pretrained
|
|
28
|
-
from careamics.
|
|
36
|
+
from careamics.prediction_utils import convert_outputs
|
|
29
37
|
from careamics.utils import check_path_exists, get_logger
|
|
30
38
|
|
|
31
|
-
from .callbacks import HyperParametersCallback
|
|
32
|
-
|
|
33
39
|
logger = get_logger(__name__)
|
|
34
40
|
|
|
35
41
|
LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
|
|
36
42
|
|
|
37
43
|
|
|
38
|
-
# TODO napari callbacks
|
|
39
|
-
# TODO: how to do AMP? How to continue training?
|
|
40
44
|
class CAREamist:
|
|
41
45
|
"""Main CAREamics class, allowing training and prediction using various algorithms.
|
|
42
46
|
|
|
43
47
|
Parameters
|
|
44
48
|
----------
|
|
45
|
-
source :
|
|
49
|
+
source : pathlib.Path or str or CAREamics Configuration
|
|
46
50
|
Path to a configuration file or a trained model.
|
|
47
|
-
work_dir :
|
|
51
|
+
work_dir : str, optional
|
|
48
52
|
Path to working directory in which to save checkpoints and logs,
|
|
49
53
|
by default None.
|
|
50
|
-
experiment_name : str,
|
|
51
|
-
Experiment name used for checkpoints
|
|
54
|
+
experiment_name : str, default="CAREamics"
|
|
55
|
+
Experiment name used for checkpoints.
|
|
56
|
+
callbacks : list of Callback, optional
|
|
57
|
+
List of callbacks to use during training and prediction, by default None.
|
|
52
58
|
|
|
53
59
|
Attributes
|
|
54
60
|
----------
|
|
55
|
-
model :
|
|
61
|
+
model : CAREamicsModule
|
|
56
62
|
CAREamics model.
|
|
57
63
|
cfg : Configuration
|
|
58
64
|
CAREamics configuration.
|
|
59
65
|
trainer : Trainer
|
|
60
66
|
PyTorch Lightning trainer.
|
|
61
|
-
experiment_logger :
|
|
67
|
+
experiment_logger : TensorBoardLogger or WandbLogger
|
|
62
68
|
Experiment logger, "wandb" or "tensorboard".
|
|
63
|
-
work_dir : Path
|
|
69
|
+
work_dir : pathlib.Path
|
|
64
70
|
Working directory.
|
|
65
|
-
train_datamodule :
|
|
71
|
+
train_datamodule : TrainDataModule
|
|
66
72
|
Training datamodule.
|
|
67
|
-
pred_datamodule :
|
|
73
|
+
pred_datamodule : PredictDataModule
|
|
68
74
|
Prediction datamodule.
|
|
69
75
|
"""
|
|
70
76
|
|
|
@@ -74,6 +80,7 @@ class CAREamist:
|
|
|
74
80
|
source: Union[Path, str],
|
|
75
81
|
work_dir: Optional[str] = None,
|
|
76
82
|
experiment_name: str = "CAREamics",
|
|
83
|
+
callbacks: Optional[list[Callback]] = None,
|
|
77
84
|
) -> None: ...
|
|
78
85
|
|
|
79
86
|
@overload
|
|
@@ -82,6 +89,7 @@ class CAREamist:
|
|
|
82
89
|
source: Configuration,
|
|
83
90
|
work_dir: Optional[str] = None,
|
|
84
91
|
experiment_name: str = "CAREamics",
|
|
92
|
+
callbacks: Optional[list[Callback]] = None,
|
|
85
93
|
) -> None: ...
|
|
86
94
|
|
|
87
95
|
def __init__(
|
|
@@ -89,6 +97,7 @@ class CAREamist:
|
|
|
89
97
|
source: Union[Path, str, Configuration],
|
|
90
98
|
work_dir: Optional[Union[Path, str]] = None,
|
|
91
99
|
experiment_name: str = "CAREamics",
|
|
100
|
+
callbacks: Optional[list[Callback]] = None,
|
|
92
101
|
) -> None:
|
|
93
102
|
"""
|
|
94
103
|
Initialize CAREamist with a configuration object or a path.
|
|
@@ -105,13 +114,15 @@ class CAREamist:
|
|
|
105
114
|
|
|
106
115
|
Parameters
|
|
107
116
|
----------
|
|
108
|
-
source :
|
|
117
|
+
source : pathlib.Path or str or CAREamics Configuration
|
|
109
118
|
Path to a configuration file or a trained model.
|
|
110
|
-
work_dir :
|
|
119
|
+
work_dir : str, optional
|
|
111
120
|
Path to working directory in which to save checkpoints and logs,
|
|
112
121
|
by default None.
|
|
113
122
|
experiment_name : str, optional
|
|
114
123
|
Experiment name used for checkpoints, by default "CAREamics".
|
|
124
|
+
callbacks : list of Callback, optional
|
|
125
|
+
List of callbacks to use during training and prediction, by default None.
|
|
115
126
|
|
|
116
127
|
Raises
|
|
117
128
|
------
|
|
@@ -164,7 +175,7 @@ class CAREamist:
|
|
|
164
175
|
self.model, self.cfg = load_pretrained(source)
|
|
165
176
|
|
|
166
177
|
# define the checkpoint saving callback
|
|
167
|
-
self.
|
|
178
|
+
self._define_callbacks(callbacks)
|
|
168
179
|
|
|
169
180
|
# instantiate logger
|
|
170
181
|
if self.cfg.training_config.has_logger():
|
|
@@ -188,32 +199,50 @@ class CAREamist:
|
|
|
188
199
|
logger=self.experiment_logger,
|
|
189
200
|
)
|
|
190
201
|
|
|
191
|
-
# change the prediction loop, necessary for tiled prediction
|
|
192
|
-
self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
|
|
193
|
-
|
|
194
202
|
# place holder for the datamodules
|
|
195
|
-
self.train_datamodule: Optional[
|
|
196
|
-
self.pred_datamodule: Optional[
|
|
203
|
+
self.train_datamodule: Optional[TrainDataModule] = None
|
|
204
|
+
self.pred_datamodule: Optional[PredictDataModule] = None
|
|
197
205
|
|
|
198
|
-
def _define_callbacks(self
|
|
206
|
+
def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
|
|
199
207
|
"""
|
|
200
208
|
Define the callbacks for the training loop.
|
|
201
209
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
List of callbacks to
|
|
210
|
+
Parameters
|
|
211
|
+
----------
|
|
212
|
+
callbacks : list of Callback, optional
|
|
213
|
+
List of callbacks to use during training and prediction, by default None.
|
|
206
214
|
"""
|
|
215
|
+
self.callbacks = [] if callbacks is None else callbacks
|
|
216
|
+
|
|
217
|
+
# check that user callbacks are not any of the CAREamics callbacks
|
|
218
|
+
for c in self.callbacks:
|
|
219
|
+
if isinstance(c, ModelCheckpoint) or isinstance(c, EarlyStopping):
|
|
220
|
+
raise ValueError(
|
|
221
|
+
"ModelCheckpoint and EarlyStopping callbacks are already defined "
|
|
222
|
+
"in CAREamics and should only be modified through the "
|
|
223
|
+
"training configuration (see TrainingConfig)."
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
if isinstance(c, HyperParametersCallback) or isinstance(
|
|
227
|
+
c, ProgressBarCallback
|
|
228
|
+
):
|
|
229
|
+
raise ValueError(
|
|
230
|
+
"HyperParameter and ProgressBar callbacks are defined internally "
|
|
231
|
+
"and should not be passed as callbacks."
|
|
232
|
+
)
|
|
233
|
+
|
|
207
234
|
# checkpoint callback saves checkpoints during training
|
|
208
|
-
self.callbacks
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
235
|
+
self.callbacks.extend(
|
|
236
|
+
[
|
|
237
|
+
HyperParametersCallback(self.cfg),
|
|
238
|
+
ModelCheckpoint(
|
|
239
|
+
dirpath=self.work_dir / Path("checkpoints"),
|
|
240
|
+
filename=self.cfg.experiment_name,
|
|
241
|
+
**self.cfg.training_config.checkpoint_callback.model_dump(),
|
|
242
|
+
),
|
|
243
|
+
ProgressBarCallback(),
|
|
244
|
+
]
|
|
245
|
+
)
|
|
217
246
|
|
|
218
247
|
# early stopping callback
|
|
219
248
|
if self.cfg.training_config.early_stopping_callback is not None:
|
|
@@ -221,16 +250,14 @@ class CAREamist:
|
|
|
221
250
|
EarlyStopping(self.cfg.training_config.early_stopping_callback)
|
|
222
251
|
)
|
|
223
252
|
|
|
224
|
-
return self.callbacks
|
|
225
|
-
|
|
226
253
|
def train(
|
|
227
254
|
self,
|
|
228
255
|
*,
|
|
229
|
-
datamodule: Optional[
|
|
230
|
-
train_source: Optional[Union[Path, str,
|
|
231
|
-
val_source: Optional[Union[Path, str,
|
|
232
|
-
train_target: Optional[Union[Path, str,
|
|
233
|
-
val_target: Optional[Union[Path, str,
|
|
256
|
+
datamodule: Optional[TrainDataModule] = None,
|
|
257
|
+
train_source: Optional[Union[Path, str, NDArray]] = None,
|
|
258
|
+
val_source: Optional[Union[Path, str, NDArray]] = None,
|
|
259
|
+
train_target: Optional[Union[Path, str, NDArray]] = None,
|
|
260
|
+
val_target: Optional[Union[Path, str, NDArray]] = None,
|
|
234
261
|
use_in_memory: bool = True,
|
|
235
262
|
val_percentage: float = 0.1,
|
|
236
263
|
val_minimum_split: int = 1,
|
|
@@ -253,15 +280,15 @@ class CAREamist:
|
|
|
253
280
|
|
|
254
281
|
Parameters
|
|
255
282
|
----------
|
|
256
|
-
datamodule :
|
|
283
|
+
datamodule : TrainDataModule, optional
|
|
257
284
|
Datamodule to train on, by default None.
|
|
258
|
-
train_source :
|
|
285
|
+
train_source : pathlib.Path or str or NDArray, optional
|
|
259
286
|
Train source, if no datamodule is provided, by default None.
|
|
260
|
-
val_source :
|
|
287
|
+
val_source : pathlib.Path or str or NDArray, optional
|
|
261
288
|
Validation source, if no datamodule is provided, by default None.
|
|
262
|
-
train_target :
|
|
289
|
+
train_target : pathlib.Path or str or NDArray, optional
|
|
263
290
|
Train target source, if no datamodule is provided, by default None.
|
|
264
|
-
val_target :
|
|
291
|
+
val_target : pathlib.Path or str or NDArray, optional
|
|
265
292
|
Validation target source, if no datamodule is provided, by default None.
|
|
266
293
|
use_in_memory : bool, optional
|
|
267
294
|
Use in memory dataset if possible, by default True.
|
|
@@ -355,17 +382,17 @@ class CAREamist:
|
|
|
355
382
|
|
|
356
383
|
else:
|
|
357
384
|
raise ValueError(
|
|
358
|
-
f"Invalid input, expected a str, Path, array or
|
|
385
|
+
f"Invalid input, expected a str, Path, array or TrainDataModule "
|
|
359
386
|
f"instance (got {type(train_source)})."
|
|
360
387
|
)
|
|
361
388
|
|
|
362
|
-
def _train_on_datamodule(self, datamodule:
|
|
389
|
+
def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
|
|
363
390
|
"""
|
|
364
391
|
Train the model on the provided datamodule.
|
|
365
392
|
|
|
366
393
|
Parameters
|
|
367
394
|
----------
|
|
368
|
-
datamodule :
|
|
395
|
+
datamodule : TrainDataModule
|
|
369
396
|
Datamodule to train on.
|
|
370
397
|
"""
|
|
371
398
|
# record datamodule
|
|
@@ -375,10 +402,10 @@ class CAREamist:
|
|
|
375
402
|
|
|
376
403
|
def _train_on_array(
|
|
377
404
|
self,
|
|
378
|
-
train_data:
|
|
379
|
-
val_data: Optional[
|
|
380
|
-
train_target: Optional[
|
|
381
|
-
val_target: Optional[
|
|
405
|
+
train_data: NDArray,
|
|
406
|
+
val_data: Optional[NDArray] = None,
|
|
407
|
+
train_target: Optional[NDArray] = None,
|
|
408
|
+
val_target: Optional[NDArray] = None,
|
|
382
409
|
val_percentage: float = 0.1,
|
|
383
410
|
val_minimum_split: int = 5,
|
|
384
411
|
) -> None:
|
|
@@ -387,13 +414,13 @@ class CAREamist:
|
|
|
387
414
|
|
|
388
415
|
Parameters
|
|
389
416
|
----------
|
|
390
|
-
train_data :
|
|
417
|
+
train_data : NDArray
|
|
391
418
|
Training data.
|
|
392
|
-
val_data :
|
|
419
|
+
val_data : NDArray, optional
|
|
393
420
|
Validation data, by default None.
|
|
394
|
-
train_target :
|
|
421
|
+
train_target : NDArray, optional
|
|
395
422
|
Train target data, by default None.
|
|
396
|
-
val_target :
|
|
423
|
+
val_target : NDArray, optional
|
|
397
424
|
Validation target data, by default None.
|
|
398
425
|
val_percentage : float, optional
|
|
399
426
|
Percentage of patches to use for validation, by default 0.1.
|
|
@@ -401,7 +428,7 @@ class CAREamist:
|
|
|
401
428
|
Minimum number of patches to use for validation, by default 5.
|
|
402
429
|
"""
|
|
403
430
|
# create datamodule
|
|
404
|
-
datamodule =
|
|
431
|
+
datamodule = TrainDataModule(
|
|
405
432
|
data_config=self.cfg.data_config,
|
|
406
433
|
train_data=train_data,
|
|
407
434
|
val_data=val_data,
|
|
@@ -429,13 +456,13 @@ class CAREamist:
|
|
|
429
456
|
|
|
430
457
|
Parameters
|
|
431
458
|
----------
|
|
432
|
-
path_to_train_data :
|
|
459
|
+
path_to_train_data : pathlib.Path or str
|
|
433
460
|
Path to the training data.
|
|
434
|
-
path_to_val_data :
|
|
461
|
+
path_to_val_data : pathlib.Path or str, optional
|
|
435
462
|
Path to validation data, by default None.
|
|
436
|
-
path_to_train_target :
|
|
463
|
+
path_to_train_target : pathlib.Path or str, optional
|
|
437
464
|
Path to train target data, by default None.
|
|
438
|
-
path_to_val_target :
|
|
465
|
+
path_to_val_target : pathlib.Path or str, optional
|
|
439
466
|
Path to validation target data, by default None.
|
|
440
467
|
use_in_memory : bool, optional
|
|
441
468
|
Use in memory dataset if possible, by default True.
|
|
@@ -457,7 +484,7 @@ class CAREamist:
|
|
|
457
484
|
path_to_val_target = check_path_exists(path_to_val_target)
|
|
458
485
|
|
|
459
486
|
# create datamodule
|
|
460
|
-
datamodule =
|
|
487
|
+
datamodule = TrainDataModule(
|
|
461
488
|
data_config=self.cfg.data_config,
|
|
462
489
|
train_data=path_to_train_data,
|
|
463
490
|
val_data=path_to_val_data,
|
|
@@ -473,11 +500,8 @@ class CAREamist:
|
|
|
473
500
|
|
|
474
501
|
@overload
|
|
475
502
|
def predict( # numpydoc ignore=GL08
|
|
476
|
-
self,
|
|
477
|
-
|
|
478
|
-
*,
|
|
479
|
-
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
480
|
-
) -> Union[list, np.ndarray]: ...
|
|
503
|
+
self, source: PredictDataModule
|
|
504
|
+
) -> Union[list[NDArray], NDArray]: ...
|
|
481
505
|
|
|
482
506
|
@overload
|
|
483
507
|
def predict( # numpydoc ignore=GL08
|
|
@@ -485,59 +509,59 @@ class CAREamist:
|
|
|
485
509
|
source: Union[Path, str],
|
|
486
510
|
*,
|
|
487
511
|
batch_size: int = 1,
|
|
488
|
-
tile_size: Optional[
|
|
489
|
-
tile_overlap:
|
|
512
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
513
|
+
tile_overlap: tuple[int, ...] = (48, 48),
|
|
490
514
|
axes: Optional[str] = None,
|
|
491
515
|
data_type: Optional[Literal["tiff", "custom"]] = None,
|
|
492
516
|
tta_transforms: bool = True,
|
|
493
|
-
dataloader_params: Optional[
|
|
517
|
+
dataloader_params: Optional[dict] = None,
|
|
494
518
|
read_source_func: Optional[Callable] = None,
|
|
495
519
|
extension_filter: str = "",
|
|
496
|
-
|
|
497
|
-
) -> Union[list, np.ndarray]: ...
|
|
520
|
+
) -> Union[list[NDArray], NDArray]: ...
|
|
498
521
|
|
|
499
522
|
@overload
|
|
500
523
|
def predict( # numpydoc ignore=GL08
|
|
501
524
|
self,
|
|
502
|
-
source:
|
|
525
|
+
source: NDArray,
|
|
503
526
|
*,
|
|
504
527
|
batch_size: int = 1,
|
|
505
|
-
tile_size: Optional[
|
|
506
|
-
tile_overlap:
|
|
528
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
529
|
+
tile_overlap: tuple[int, ...] = (48, 48),
|
|
507
530
|
axes: Optional[str] = None,
|
|
508
531
|
data_type: Optional[Literal["array"]] = None,
|
|
509
532
|
tta_transforms: bool = True,
|
|
510
|
-
dataloader_params: Optional[
|
|
511
|
-
|
|
512
|
-
) -> Union[list, np.ndarray]: ...
|
|
533
|
+
dataloader_params: Optional[dict] = None,
|
|
534
|
+
) -> Union[list[NDArray], NDArray]: ...
|
|
513
535
|
|
|
514
536
|
def predict(
|
|
515
537
|
self,
|
|
516
|
-
source: Union[
|
|
538
|
+
source: Union[PredictDataModule, Path, str, NDArray],
|
|
517
539
|
*,
|
|
518
|
-
batch_size: int =
|
|
519
|
-
tile_size: Optional[
|
|
520
|
-
tile_overlap:
|
|
540
|
+
batch_size: Optional[int] = None,
|
|
541
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
542
|
+
tile_overlap: tuple[int, ...] = (48, 48),
|
|
521
543
|
axes: Optional[str] = None,
|
|
522
544
|
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
|
|
523
545
|
tta_transforms: bool = True,
|
|
524
|
-
dataloader_params: Optional[
|
|
546
|
+
dataloader_params: Optional[dict] = None,
|
|
525
547
|
read_source_func: Optional[Callable] = None,
|
|
526
548
|
extension_filter: str = "",
|
|
527
|
-
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
528
549
|
**kwargs: Any,
|
|
529
|
-
) -> Union[
|
|
550
|
+
) -> Union[list[NDArray], NDArray]:
|
|
530
551
|
"""
|
|
531
552
|
Make predictions on the provided data.
|
|
532
553
|
|
|
533
|
-
Input can be a
|
|
554
|
+
Input can be a PredictDataModule instance, a path to a data file, or a numpy
|
|
555
|
+
array.
|
|
534
556
|
|
|
535
557
|
If `data_type`, `axes` and `tile_size` are not provided, the training
|
|
536
558
|
configuration parameters will be used, with the `patch_size` instead of
|
|
537
559
|
`tile_size`.
|
|
538
560
|
|
|
539
561
|
Test-time augmentation (TTA) can be switched off using the `tta_transforms`
|
|
540
|
-
parameter.
|
|
562
|
+
parameter. TTA applies all possible flips and 90-degree
|
|
563
|
+
rotations to the prediction input and averages the predictions. TTA augmentation
|
|
564
|
+
should not be used if you did not train with these augmentations.
|
|
541
565
|
|
|
542
566
|
Note that if you are using a UNet model and tiling, the tile size must be
|
|
543
567
|
divisible in every dimension by 2**d, where d is the depth of the model. This
|
|
@@ -547,242 +571,136 @@ class CAREamist:
|
|
|
547
571
|
|
|
548
572
|
Parameters
|
|
549
573
|
----------
|
|
550
|
-
source :
|
|
574
|
+
source : PredictDataModule, pathlib.Path, str or numpy.ndarray
|
|
551
575
|
Data to predict on.
|
|
552
|
-
batch_size : int,
|
|
553
|
-
Batch size for prediction
|
|
554
|
-
tile_size :
|
|
555
|
-
Size of the tiles to use for prediction
|
|
556
|
-
tile_overlap :
|
|
557
|
-
Overlap between tiles
|
|
558
|
-
axes :
|
|
576
|
+
batch_size : int, optional
|
|
577
|
+
Batch size for prediction, by default the configuration batch size.
|
|
578
|
+
tile_size : tuple of int, optional
|
|
579
|
+
Size of the tiles to use for prediction.
|
|
580
|
+
tile_overlap : tuple of int, default=(48, 48)
|
|
581
|
+
Overlap between tiles.
|
|
582
|
+
axes : str, optional
|
|
559
583
|
Axes of the input data, by default None.
|
|
560
|
-
data_type :
|
|
561
|
-
Type of the input data
|
|
562
|
-
tta_transforms : bool,
|
|
563
|
-
Whether to apply test-time augmentation
|
|
564
|
-
dataloader_params :
|
|
565
|
-
Parameters to pass to the dataloader
|
|
566
|
-
read_source_func :
|
|
567
|
-
Function to read the source data
|
|
568
|
-
extension_filter : str,
|
|
569
|
-
Filter for the file extension
|
|
570
|
-
checkpoint : Optional[Literal["best", "last"]], optional
|
|
571
|
-
Checkpoint to use for prediction, by default None.
|
|
584
|
+
data_type : {"array", "tiff", "custom"}, optional
|
|
585
|
+
Type of the input data.
|
|
586
|
+
tta_transforms : bool, default=True
|
|
587
|
+
Whether to apply test-time augmentation.
|
|
588
|
+
dataloader_params : dict, optional
|
|
589
|
+
Parameters to pass to the dataloader.
|
|
590
|
+
read_source_func : Callable, optional
|
|
591
|
+
Function to read the source data.
|
|
592
|
+
extension_filter : str, default=""
|
|
593
|
+
Filter for the file extension.
|
|
572
594
|
**kwargs : Any
|
|
573
595
|
Unused.
|
|
574
596
|
|
|
575
597
|
Returns
|
|
576
598
|
-------
|
|
577
|
-
|
|
599
|
+
list of NDArray or NDArray
|
|
578
600
|
Predictions made by the model.
|
|
579
601
|
|
|
580
|
-
Raises
|
|
581
|
-
------
|
|
582
|
-
ValueError
|
|
583
|
-
If the input is not a CAREamicsClay instance, a path or a numpy array.
|
|
584
|
-
"""
|
|
585
|
-
if isinstance(source, CAREamicsPredictData):
|
|
586
|
-
# record datamodule
|
|
587
|
-
self.pred_datamodule = source
|
|
588
|
-
|
|
589
|
-
return self.trainer.predict(
|
|
590
|
-
model=self.model, datamodule=source, ckpt_path=checkpoint
|
|
591
|
-
)
|
|
592
|
-
else:
|
|
593
|
-
if self.cfg is None:
|
|
594
|
-
raise ValueError(
|
|
595
|
-
"No configuration found. Train a model or load from a "
|
|
596
|
-
"checkpoint before predicting."
|
|
597
|
-
)
|
|
598
|
-
# create predict config, reuse training config if parameters missing
|
|
599
|
-
prediction_config = create_inference_configuration(
|
|
600
|
-
configuration=self.cfg,
|
|
601
|
-
tile_size=tile_size,
|
|
602
|
-
tile_overlap=tile_overlap,
|
|
603
|
-
data_type=data_type,
|
|
604
|
-
axes=axes,
|
|
605
|
-
tta_transforms=tta_transforms,
|
|
606
|
-
batch_size=batch_size,
|
|
607
|
-
)
|
|
608
|
-
|
|
609
|
-
# remove batch from dataloader parameters (priority given to config)
|
|
610
|
-
if dataloader_params is None:
|
|
611
|
-
dataloader_params = {}
|
|
612
|
-
if "batch_size" in dataloader_params:
|
|
613
|
-
del dataloader_params["batch_size"]
|
|
614
|
-
|
|
615
|
-
if isinstance(source, Path) or isinstance(source, str):
|
|
616
|
-
# Check the source
|
|
617
|
-
source_path = check_path_exists(source)
|
|
618
|
-
|
|
619
|
-
# create datamodule
|
|
620
|
-
datamodule = CAREamicsPredictData(
|
|
621
|
-
pred_config=prediction_config,
|
|
622
|
-
pred_data=source_path,
|
|
623
|
-
read_source_func=read_source_func,
|
|
624
|
-
extension_filter=extension_filter,
|
|
625
|
-
dataloader_params=dataloader_params,
|
|
626
|
-
)
|
|
627
|
-
|
|
628
|
-
# record datamodule
|
|
629
|
-
self.pred_datamodule = datamodule
|
|
630
|
-
|
|
631
|
-
return self.trainer.predict(
|
|
632
|
-
model=self.model, datamodule=datamodule, ckpt_path=checkpoint
|
|
633
|
-
)
|
|
634
|
-
|
|
635
|
-
elif isinstance(source, np.ndarray):
|
|
636
|
-
# create datamodule
|
|
637
|
-
datamodule = CAREamicsPredictData(
|
|
638
|
-
pred_config=prediction_config,
|
|
639
|
-
pred_data=source,
|
|
640
|
-
dataloader_params=dataloader_params,
|
|
641
|
-
)
|
|
642
|
-
|
|
643
|
-
# record datamodule
|
|
644
|
-
self.pred_datamodule = datamodule
|
|
645
|
-
|
|
646
|
-
return self.trainer.predict(
|
|
647
|
-
model=self.model, datamodule=datamodule, ckpt_path=checkpoint
|
|
648
|
-
)
|
|
649
|
-
|
|
650
|
-
else:
|
|
651
|
-
raise ValueError(
|
|
652
|
-
f"Invalid input. Expected a CAREamicsWood instance, paths or "
|
|
653
|
-
f"np.ndarray (got {type(source)})."
|
|
654
|
-
)
|
|
655
|
-
|
|
656
|
-
def _create_data_for_bmz(
|
|
657
|
-
self,
|
|
658
|
-
input_array: Optional[np.ndarray] = None,
|
|
659
|
-
) -> np.ndarray:
|
|
660
|
-
"""Create data for BMZ export.
|
|
661
|
-
|
|
662
|
-
If no `input_array` is provided, this method checks if there is a prediction
|
|
663
|
-
datamodule, or a training data module, to extract a patch. If none exists,
|
|
664
|
-
then a random aray is created.
|
|
665
|
-
|
|
666
|
-
If there is a non-singleton batch dimension, this method returns only the first
|
|
667
|
-
element.
|
|
668
|
-
|
|
669
|
-
Parameters
|
|
670
|
-
----------
|
|
671
|
-
input_array : Optional[np.ndarray], optional
|
|
672
|
-
Input array, by default None.
|
|
673
|
-
|
|
674
|
-
Returns
|
|
675
|
-
-------
|
|
676
|
-
np.ndarray
|
|
677
|
-
Input data for BMZ export.
|
|
678
|
-
|
|
679
602
|
Raises
|
|
680
603
|
------
|
|
681
604
|
ValueError
|
|
682
605
|
If mean and std are not provided in the configuration.
|
|
606
|
+
ValueError
|
|
607
|
+
If tile size is not divisible by 2**depth for UNet models.
|
|
608
|
+
ValueError
|
|
609
|
+
If tile overlap is not specified.
|
|
683
610
|
"""
|
|
684
|
-
if
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
#
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
if input_patch.shape[0] > 1:
|
|
729
|
-
input_patch = input_patch[[0], ...] # keep singleton dim
|
|
611
|
+
if (
|
|
612
|
+
self.cfg.data_config.image_means is None
|
|
613
|
+
or self.cfg.data_config.image_stds is None
|
|
614
|
+
):
|
|
615
|
+
raise ValueError("Mean and std must be provided in the configuration.")
|
|
616
|
+
|
|
617
|
+
# tile size for UNets
|
|
618
|
+
if tile_size is not None:
|
|
619
|
+
model = self.cfg.algorithm_config.model
|
|
620
|
+
|
|
621
|
+
if model.architecture == SupportedArchitecture.UNET.value:
|
|
622
|
+
# tile size must be equal to k*2^n, where n is the number of pooling
|
|
623
|
+
# layers (equal to the depth) and k is an integer
|
|
624
|
+
depth = model.depth
|
|
625
|
+
tile_increment = 2**depth
|
|
626
|
+
|
|
627
|
+
for i, t in enumerate(tile_size):
|
|
628
|
+
if t % tile_increment != 0:
|
|
629
|
+
raise ValueError(
|
|
630
|
+
f"Tile size must be divisible by {tile_increment} along "
|
|
631
|
+
f"all axes (got {t} for axis {i}). If your image size is "
|
|
632
|
+
f"smaller along one axis (e.g. Z), consider padding the "
|
|
633
|
+
f"image."
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
# tile overlaps must be specified
|
|
637
|
+
if tile_overlap is None:
|
|
638
|
+
raise ValueError("Tile overlap must be specified.")
|
|
639
|
+
|
|
640
|
+
# create the prediction
|
|
641
|
+
self.pred_datamodule = create_predict_datamodule(
|
|
642
|
+
pred_data=source,
|
|
643
|
+
data_type=data_type or self.cfg.data_config.data_type,
|
|
644
|
+
axes=axes or self.cfg.data_config.axes,
|
|
645
|
+
image_means=self.cfg.data_config.image_means,
|
|
646
|
+
image_stds=self.cfg.data_config.image_stds,
|
|
647
|
+
tile_size=tile_size,
|
|
648
|
+
tile_overlap=tile_overlap,
|
|
649
|
+
batch_size=batch_size or self.cfg.data_config.batch_size,
|
|
650
|
+
tta_transforms=tta_transforms,
|
|
651
|
+
read_source_func=read_source_func,
|
|
652
|
+
extension_filter=extension_filter,
|
|
653
|
+
dataloader_params=dataloader_params,
|
|
654
|
+
)
|
|
730
655
|
|
|
731
|
-
|
|
656
|
+
# predict
|
|
657
|
+
predictions = self.trainer.predict(
|
|
658
|
+
model=self.model, datamodule=self.pred_datamodule
|
|
659
|
+
)
|
|
660
|
+
return convert_outputs(predictions, self.pred_datamodule.tiled)
|
|
732
661
|
|
|
733
662
|
def export_to_bmz(
|
|
734
663
|
self,
|
|
735
664
|
path: Union[Path, str],
|
|
736
665
|
name: str,
|
|
737
|
-
|
|
738
|
-
|
|
666
|
+
input_array: NDArray,
|
|
667
|
+
authors: list[dict],
|
|
739
668
|
general_description: str = "",
|
|
740
|
-
channel_names: Optional[
|
|
669
|
+
channel_names: Optional[list[str]] = None,
|
|
741
670
|
data_description: Optional[str] = None,
|
|
742
671
|
) -> None:
|
|
743
672
|
"""Export the model to the BioImage Model Zoo format.
|
|
744
673
|
|
|
745
|
-
Input array must be of
|
|
674
|
+
Input array must be of the same dimensions as the axes recorded in the
|
|
675
|
+
configuration of the `CAREamist`.
|
|
746
676
|
|
|
747
677
|
Parameters
|
|
748
678
|
----------
|
|
749
|
-
path :
|
|
679
|
+
path : pathlib.Path or str
|
|
750
680
|
Path to save the model.
|
|
751
681
|
name : str
|
|
752
682
|
Name of the model.
|
|
753
|
-
|
|
683
|
+
input_array : NDArray
|
|
684
|
+
Input array used to validate the model and as example.
|
|
685
|
+
authors : list of dict
|
|
754
686
|
List of authors of the model.
|
|
755
|
-
input_array : Optional[np.ndarray], optional
|
|
756
|
-
Input array for the model, must be of shape SC(Z)YX, by default None.
|
|
757
687
|
general_description : str
|
|
758
688
|
General description of the model, used in the metadata of the BMZ archive.
|
|
759
|
-
channel_names :
|
|
689
|
+
channel_names : list of str, optional
|
|
760
690
|
Channel names, by default None.
|
|
761
|
-
data_description :
|
|
691
|
+
data_description : str, optional
|
|
762
692
|
Description of the data, by default None.
|
|
763
693
|
"""
|
|
764
|
-
|
|
694
|
+
# TODO: add in docs that it is expected that input_array dimensions match
|
|
695
|
+
# those in data_config
|
|
765
696
|
|
|
766
|
-
# axes need to be reformated for the export because reshaping was done in the
|
|
767
|
-
# datamodule
|
|
768
|
-
if "Z" in self.cfg.data_config.axes:
|
|
769
|
-
axes = "SCZYX"
|
|
770
|
-
else:
|
|
771
|
-
axes = "SCYX"
|
|
772
|
-
|
|
773
|
-
# predict output, remove extra dimensions for the purpose of the prediction
|
|
774
697
|
output_patch = self.predict(
|
|
775
|
-
|
|
698
|
+
input_array,
|
|
776
699
|
data_type=SupportedData.ARRAY.value,
|
|
777
|
-
axes=axes,
|
|
778
700
|
tta_transforms=False,
|
|
779
701
|
)
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
raise ValueError(
|
|
783
|
-
f"Numpy array required for export to BioImage Model Zoo, got "
|
|
784
|
-
f"{type(output_patch)}."
|
|
785
|
-
)
|
|
702
|
+
output = np.concatenate(output_patch, axis=0)
|
|
703
|
+
input_array = reshape_array(input_array, self.cfg.data_config.axes)
|
|
786
704
|
|
|
787
705
|
export_to_bmz(
|
|
788
706
|
model=self.model,
|
|
@@ -791,8 +709,8 @@ class CAREamist:
|
|
|
791
709
|
name=name,
|
|
792
710
|
general_description=general_description,
|
|
793
711
|
authors=authors,
|
|
794
|
-
input_array=
|
|
795
|
-
output_array=
|
|
712
|
+
input_array=input_array,
|
|
713
|
+
output_array=output,
|
|
796
714
|
channel_names=channel_names,
|
|
797
715
|
data_description=data_description,
|
|
798
716
|
)
|