careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of careamics might be problematic.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
ADDED
@@ -0,0 +1,761 @@
"""A class to train, predict and export models in CAREamics."""

from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload

import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    Callback,
    EarlyStopping,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from careamics.callbacks import ProgressBarCallback
from careamics.config import (
    Configuration,
    create_inference_configuration,
    load_configuration,
)
from careamics.config.inference_model import TRANSFORMS_UNION
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
from careamics.lightning_datamodule import CAREamicsTrainData
from careamics.lightning_module import CAREamicsModule
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.utils import check_path_exists, get_logger

from .callbacks import HyperParametersCallback

logger = get_logger(__name__)

LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]


# TODO napari callbacks
# TODO: how to do AMP? How to continue training?
class CAREamist:
    """Main CAREamics class, allowing training and prediction using various algorithms.

    Parameters
    ----------
    source : Union[Path, str, Configuration]
        Path to a configuration file or a trained model.
    work_dir : Optional[str], optional
        Path to working directory in which to save checkpoints and logs,
        by default None.
    experiment_name : str, optional
        Experiment name used for checkpoints, by default "CAREamics".

    Attributes
    ----------
    model : CAREamicsModule
        CAREamics model.
    cfg : Configuration
        CAREamics configuration.
    trainer : Trainer
        PyTorch Lightning trainer.
    experiment_logger : Optional[Union[TensorBoardLogger, WandbLogger]]
        Experiment logger, "wandb" or "tensorboard".
    work_dir : Path
        Working directory.
    train_datamodule : Optional[CAREamicsTrainData]
        Training datamodule.
    pred_datamodule : Optional[CAREamicsPredictData]
        Prediction datamodule.
    """

    @overload
    def __init__(  # numpydoc ignore=GL08
        self,
        source: Union[Path, str],
        work_dir: Optional[str] = None,
        experiment_name: str = "CAREamics",
    ) -> None:
        ...

    @overload
    def __init__(  # numpydoc ignore=GL08
        self,
        source: Configuration,
        work_dir: Optional[str] = None,
        experiment_name: str = "CAREamics",
    ) -> None:
        ...

    def __init__(
        self,
        source: Union[Path, str, Configuration],
        work_dir: Optional[Union[Path, str]] = None,
        experiment_name: str = "CAREamics",
    ) -> None:
        """
        Initialize CAREamist with a configuration object or a path.

        A configuration object can be created directly by calling `Configuration`,
        by using the configuration factory, or by loading a configuration from a
        yaml file.

        The path can point to either a yaml file with parameters or a saved
        checkpoint.

        If no working directory is provided, the current working directory is used.

        If `source` is a checkpoint, then `experiment_name` is used to name the
        checkpoint, and is recorded in the configuration.

        Parameters
        ----------
        source : Union[Path, str, Configuration]
            Path to a configuration file or a trained model.
        work_dir : Optional[str], optional
            Path to working directory in which to save checkpoints and logs,
            by default None.
        experiment_name : str, optional
            Experiment name used for checkpoints, by default "CAREamics".

        Raises
        ------
        NotImplementedError
            If the model is loaded from the BioImage Model Zoo.
        ValueError
            If no hyperparameters are found in the checkpoint.
        ValueError
            If no data module hyperparameters are found in the checkpoint.
        """
        super().__init__()

        # select current working directory if work_dir is None
        if work_dir is None:
            self.work_dir = Path.cwd()
            logger.warning(
                f"No working directory provided. Using current working directory: "
                f"{self.work_dir}."
            )
        else:
            self.work_dir = Path(work_dir)

        # configuration object
        if isinstance(source, Configuration):
            self.cfg = source

            # instantiate model
            self.model = CAREamicsModule(
                algorithm_config=self.cfg.algorithm_config,
            )

        # path to configuration file or model
        else:
            source = check_path_exists(source)

            # configuration file
            if source.is_file() and (
                source.suffix == ".yaml" or source.suffix == ".yml"
            ):
                # load configuration
                self.cfg = load_configuration(source)

                # instantiate model
                self.model = CAREamicsModule(
                    algorithm_config=self.cfg.algorithm_config,
                )

            # attempt loading a pre-trained model
            else:
                self.model, self.cfg = load_pretrained(source)

        # define the checkpoint saving callback
        self.callbacks = self._define_callbacks()

        # instantiate logger
        if self.cfg.training_config.has_logger():
            if self.cfg.training_config.logger == SupportedLogger.WANDB:
                self.experiment_logger: LOGGER_TYPES = WandbLogger(
                    name=experiment_name,
                    save_dir=self.work_dir / Path("logs"),
                )
            elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
                self.experiment_logger = TensorBoardLogger(
                    save_dir=self.work_dir / Path("logs"),
                )
        else:
            self.experiment_logger = None

        # instantiate trainer
        self.trainer = Trainer(
            max_epochs=self.cfg.training_config.num_epochs,
            callbacks=self.callbacks,
            default_root_dir=self.work_dir,
            logger=self.experiment_logger,
        )

        # change the prediction loop, necessary for tiled prediction
        self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)

        # placeholders for the datamodules
        self.train_datamodule: Optional[CAREamicsTrainData] = None
        self.pred_datamodule: Optional[CAREamicsPredictData] = None

    def _define_callbacks(self) -> List[Callback]:
        """
        Define the callbacks for the training loop.

        Returns
        -------
        List[Callback]
            List of callbacks to be used during training.
        """
        # checkpoint callback saves checkpoints during training
        self.callbacks = [
            HyperParametersCallback(self.cfg),
            ModelCheckpoint(
                dirpath=self.work_dir / Path("checkpoints"),
                filename=self.cfg.experiment_name,
                **self.cfg.training_config.checkpoint_callback.model_dump(),
            ),
            ProgressBarCallback(),
        ]

        # early stopping callback
        if self.cfg.training_config.early_stopping_callback is not None:
            self.callbacks.append(
                EarlyStopping(self.cfg.training_config.early_stopping_callback)
            )

        return self.callbacks

    def train(
        self,
        *,
        datamodule: Optional[CAREamicsTrainData] = None,
        train_source: Optional[Union[Path, str, np.ndarray]] = None,
        val_source: Optional[Union[Path, str, np.ndarray]] = None,
        train_target: Optional[Union[Path, str, np.ndarray]] = None,
        val_target: Optional[Union[Path, str, np.ndarray]] = None,
        use_in_memory: bool = True,
        val_percentage: float = 0.1,
        val_minimum_split: int = 1,
    ) -> None:
        """
        Train the model on the provided data.

        If a datamodule is provided, then training will be performed using it.
        Alternatively, the training data can be provided as arrays or paths.

        If `use_in_memory` is set to True, a source provided as Path or str will be
        loaded in memory if it fits. Otherwise, training will be performed by loading
        patches from the files one by one. Training on arrays is always performed
        in memory.

        If no validation source is provided, then the validation is extracted from
        the training data using `val_percentage` and `val_minimum_split`. In the case
        of data provided as Path or str, the percentage and minimum number are applied
        to the number of files; for arrays, to the number of patches.

        Parameters
        ----------
        datamodule : Optional[CAREamicsTrainData], optional
            Datamodule to train on, by default None.
        train_source : Optional[Union[Path, str, np.ndarray]], optional
            Train source, if no datamodule is provided, by default None.
        val_source : Optional[Union[Path, str, np.ndarray]], optional
            Validation source, if no datamodule is provided, by default None.
        train_target : Optional[Union[Path, str, np.ndarray]], optional
            Train target source, if no datamodule is provided, by default None.
        val_target : Optional[Union[Path, str, np.ndarray]], optional
            Validation target source, if no datamodule is provided, by default None.
        use_in_memory : bool, optional
            Use an in-memory dataset if possible, by default True.
        val_percentage : float, optional
            Percentage of validation extracted from training data, by default 0.1.
        val_minimum_split : int, optional
            Minimum number of validation patches or files extracted from training
            data, by default 1.

        Raises
        ------
        ValueError
            If both `datamodule` and `train_source` are provided.
        ValueError
            If sources are not of the same type (e.g. train is an array and val is
            a Path).
        ValueError
            If a training target is provided to N2V.
        ValueError
            If neither a datamodule nor a source is provided.
        """
        if datamodule is not None and train_source:
            raise ValueError(
                "Only one of `datamodule` and `train_source` can be provided."
            )

        # check that inputs are the same type
        source_types = {
            type(s)
            for s in (train_source, val_source, train_target, val_target)
            if s is not None
        }
        if len(source_types) > 1:
            raise ValueError("All sources should be of the same type.")

        # train
        if datamodule is not None:
            self._train_on_datamodule(datamodule=datamodule)

        else:
            # raise error if target is provided to N2V
            if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:
                if train_target is not None:
                    raise ValueError(
                        "Training target not compatible with N2V training."
                    )

            # dispatch the training
            if isinstance(train_source, np.ndarray):
                # mypy checks
                assert isinstance(val_source, np.ndarray) or val_source is None
                assert isinstance(train_target, np.ndarray) or train_target is None
                assert isinstance(val_target, np.ndarray) or val_target is None

                self._train_on_array(
                    train_source,
                    val_source,
                    train_target,
                    val_target,
                    val_percentage,
                    val_minimum_split,
                )

            elif isinstance(train_source, Path) or isinstance(train_source, str):
                # mypy checks
                assert (
                    isinstance(val_source, Path)
                    or isinstance(val_source, str)
                    or val_source is None
                )
                assert (
                    isinstance(train_target, Path)
                    or isinstance(train_target, str)
                    or train_target is None
                )
                assert (
                    isinstance(val_target, Path)
                    or isinstance(val_target, str)
                    or val_target is None
                )

                self._train_on_path(
                    train_source,
                    val_source,
                    train_target,
                    val_target,
                    use_in_memory,
                    val_percentage,
                    val_minimum_split,
                )

            else:
                raise ValueError(
                    f"Invalid input, expected a str, Path, array or "
                    f"CAREamicsTrainData instance (got {type(train_source)})."
                )

    def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
        """
        Train the model on the provided datamodule.

        Parameters
        ----------
        datamodule : CAREamicsTrainData
            Datamodule to train on.
        """
        # record datamodule
        self.train_datamodule = datamodule

        self.trainer.fit(self.model, datamodule=datamodule)

    def _train_on_array(
        self,
        train_data: np.ndarray,
        val_data: Optional[np.ndarray] = None,
        train_target: Optional[np.ndarray] = None,
        val_target: Optional[np.ndarray] = None,
        val_percentage: float = 0.1,
        val_minimum_split: int = 5,
    ) -> None:
        """
        Train the model on the provided data arrays.

        Parameters
        ----------
        train_data : np.ndarray
            Training data.
        val_data : Optional[np.ndarray], optional
            Validation data, by default None.
        train_target : Optional[np.ndarray], optional
            Train target data, by default None.
        val_target : Optional[np.ndarray], optional
            Validation target data, by default None.
        val_percentage : float, optional
            Percentage of patches to use for validation, by default 0.1.
        val_minimum_split : int, optional
            Minimum number of patches to use for validation, by default 5.
        """
        # create datamodule
        datamodule = CAREamicsTrainData(
            data_config=self.cfg.data_config,
            train_data=train_data,
            val_data=val_data,
            train_data_target=train_target,
            val_data_target=val_target,
            val_percentage=val_percentage,
            val_minimum_split=val_minimum_split,
        )

        # train
        self.train(datamodule=datamodule)

    def _train_on_path(
        self,
        path_to_train_data: Union[Path, str],
        path_to_val_data: Optional[Union[Path, str]] = None,
        path_to_train_target: Optional[Union[Path, str]] = None,
        path_to_val_target: Optional[Union[Path, str]] = None,
        use_in_memory: bool = True,
        val_percentage: float = 0.1,
        val_minimum_split: int = 1,
    ) -> None:
        """
        Train the model on the provided data paths.

        Parameters
        ----------
        path_to_train_data : Union[Path, str]
            Path to the training data.
        path_to_val_data : Optional[Union[Path, str]], optional
            Path to validation data, by default None.
        path_to_train_target : Optional[Union[Path, str]], optional
            Path to train target data, by default None.
        path_to_val_target : Optional[Union[Path, str]], optional
            Path to validation target data, by default None.
        use_in_memory : bool, optional
            Use an in-memory dataset if possible, by default True.
        val_percentage : float, optional
            Percentage of files to use for validation, by default 0.1.
        val_minimum_split : int, optional
            Minimum number of files to use for validation, by default 1.
        """
        # sanity check on data (path exists)
        path_to_train_data = check_path_exists(path_to_train_data)

        if path_to_val_data is not None:
            path_to_val_data = check_path_exists(path_to_val_data)

        if path_to_train_target is not None:
            path_to_train_target = check_path_exists(path_to_train_target)

        if path_to_val_target is not None:
            path_to_val_target = check_path_exists(path_to_val_target)

        # create datamodule
        datamodule = CAREamicsTrainData(
            data_config=self.cfg.data_config,
            train_data=path_to_train_data,
            val_data=path_to_val_data,
            train_data_target=path_to_train_target,
            val_data_target=path_to_val_target,
            use_in_memory=use_in_memory,
            val_percentage=val_percentage,
            val_minimum_split=val_minimum_split,
        )

        # train
        self.train(datamodule=datamodule)

    @overload
    def predict(  # numpydoc ignore=GL08
        self,
        source: CAREamicsPredictData,
        *,
        checkpoint: Optional[Literal["best", "last"]] = None,
    ) -> Union[list, np.ndarray]:
        ...

    @overload
    def predict(  # numpydoc ignore=GL08
        self,
        source: Union[Path, str],
        *,
        batch_size: int = 1,
        tile_size: Optional[Tuple[int, ...]] = None,
        tile_overlap: Tuple[int, ...] = (48, 48),
        axes: Optional[str] = None,
        data_type: Optional[Literal["tiff", "custom"]] = None,
        transforms: Optional[List[TRANSFORMS_UNION]] = None,
        tta_transforms: bool = True,
        dataloader_params: Optional[Dict] = None,
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
        checkpoint: Optional[Literal["best", "last"]] = None,
    ) -> Union[list, np.ndarray]:
        ...

    @overload
    def predict(  # numpydoc ignore=GL08
        self,
        source: np.ndarray,
        *,
        batch_size: int = 1,
        tile_size: Optional[Tuple[int, ...]] = None,
        tile_overlap: Tuple[int, ...] = (48, 48),
        axes: Optional[str] = None,
        data_type: Optional[Literal["array"]] = None,
        transforms: Optional[List[TRANSFORMS_UNION]] = None,
        tta_transforms: bool = True,
        dataloader_params: Optional[Dict] = None,
        checkpoint: Optional[Literal["best", "last"]] = None,
    ) -> Union[list, np.ndarray]:
        ...

    def predict(
        self,
        source: Union[CAREamicsPredictData, Path, str, np.ndarray],
        *,
        batch_size: int = 1,
        tile_size: Optional[Tuple[int, ...]] = None,
        tile_overlap: Tuple[int, ...] = (48, 48),
        axes: Optional[str] = None,
        data_type: Optional[Literal["array", "tiff", "custom"]] = None,
        transforms: Optional[List[TRANSFORMS_UNION]] = None,
        tta_transforms: bool = True,
        dataloader_params: Optional[Dict] = None,
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
        checkpoint: Optional[Literal["best", "last"]] = None,
        **kwargs: Any,
    ) -> Union[List[np.ndarray], np.ndarray]:
        """
        Make predictions on the provided data.

        Input can be a CAREamicsPredictData instance, a path to a data file, or a
        numpy array.

        If `data_type`, `axes` and `tile_size` are not provided, the training
        configuration parameters will be used, with the `patch_size` instead of
        `tile_size`.

        The default transforms are defined in the `InferenceModel` Pydantic model.

        Test-time augmentation (TTA) can be switched off using the `tta_transforms`
        parameter.

        Parameters
        ----------
        source : Union[CAREamicsPredictData, Path, str, np.ndarray]
            Data to predict on.
        batch_size : int, optional
            Batch size for prediction, by default 1.
        tile_size : Optional[Tuple[int, ...]], optional
            Size of the tiles to use for prediction, by default None.
        tile_overlap : Tuple[int, ...], optional
            Overlap between tiles, by default (48, 48).
        axes : Optional[str], optional
            Axes of the input data, by default None.
        data_type : Optional[Literal["array", "tiff", "custom"]], optional
            Type of the input data, by default None.
        transforms : Optional[List[TRANSFORMS_UNION]], optional
            List of transforms to apply to the data, by default None.
        tta_transforms : bool, optional
            Whether to apply test-time augmentation, by default True.
        dataloader_params : Optional[Dict], optional
            Parameters to pass to the dataloader, by default None.
        read_source_func : Optional[Callable], optional
            Function to read the source data, by default None.
        extension_filter : str, optional
            Filter for the file extension, by default "".
        checkpoint : Optional[Literal["best", "last"]], optional
            Checkpoint to use for prediction, by default None.
        **kwargs : Any
            Unused.

        Returns
        -------
        Union[List[np.ndarray], np.ndarray]
            Predictions made by the model.

        Raises
        ------
        ValueError
            If the input is not a CAREamicsPredictData instance, a path or a numpy
            array.
        """
        if isinstance(source, CAREamicsPredictData):
            # record datamodule
            self.pred_datamodule = source

            return self.trainer.predict(
                model=self.model, datamodule=source, ckpt_path=checkpoint
            )
        else:
            if self.cfg is None:
                raise ValueError(
                    "No configuration found. Train a model or load from a "
                    "checkpoint before predicting."
                )
            # create predict config, reuse training config if parameters missing
            prediction_config = create_inference_configuration(
                training_configuration=self.cfg,
                tile_size=tile_size,
                tile_overlap=tile_overlap,
                data_type=data_type,
                axes=axes,
                transforms=transforms,
                tta_transforms=tta_transforms,
                batch_size=batch_size,
            )

            # remove batch size from dataloader parameters (priority given to config)
            if dataloader_params is None:
                dataloader_params = {}
            if "batch_size" in dataloader_params:
                del dataloader_params["batch_size"]

            if isinstance(source, Path) or isinstance(source, str):
                # check the source
                source_path = check_path_exists(source)

                # create datamodule
                datamodule = CAREamicsPredictData(
                    pred_config=prediction_config,
                    pred_data=source_path,
                    read_source_func=read_source_func,
                    extension_filter=extension_filter,
                    dataloader_params=dataloader_params,
                )

                # record datamodule
                self.pred_datamodule = datamodule

                return self.trainer.predict(
                    model=self.model, datamodule=datamodule, ckpt_path=checkpoint
                )

            elif isinstance(source, np.ndarray):
                # create datamodule
                datamodule = CAREamicsPredictData(
                    pred_config=prediction_config,
                    pred_data=source,
                    dataloader_params=dataloader_params,
                )

                # record datamodule
                self.pred_datamodule = datamodule

                return self.trainer.predict(
                    model=self.model, datamodule=datamodule, ckpt_path=checkpoint
                )

            else:
                raise ValueError(
                    f"Invalid input. Expected a CAREamicsPredictData instance, a "
                    f"path or an np.ndarray (got {type(source)})."
                )

    def export_to_bmz(
        self,
        path: Union[Path, str],
        name: str,
        authors: List[dict],
        input_array: Optional[np.ndarray] = None,
        general_description: str = "",
        channel_names: Optional[List[str]] = None,
        data_description: Optional[str] = None,
    ) -> None:
        """Export the model to the BioImage Model Zoo format.

        The input array must be of shape SC(Z)YX, with S and C singleton dimensions.

        Parameters
        ----------
        path : Union[Path, str]
            Path to save the model.
        name : str
            Name of the model.
        authors : List[dict]
            List of authors of the model.
        input_array : Optional[np.ndarray], optional
            Input array for the model, must be of shape SC(Z)YX, by default None.
        general_description : str
            General description of the model, used in the metadata of the BMZ archive.
        channel_names : Optional[List[str]], optional
            Channel names, by default None.
        data_description : Optional[str], optional
            Description of the data, by default None.
        """
        if input_array is None:
            # generate images, priority is given to the prediction data module
            if self.pred_datamodule is not None:
                # unpack a batch, ignore masks or targets
                input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))

                # convert torch.Tensor to numpy
                input_patch = input_patch.numpy()
            elif self.train_datamodule is not None:
                input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
                input_patch = input_patch.numpy()
            else:
                if (
                    self.cfg.data_config.mean is None
                    or self.cfg.data_config.std is None
                ):
                    raise ValueError(
                        "Mean and std cannot be None in the configuration in order "
                        "to export to the BMZ format. Was the model trained?"
                    )

                # create a random input array
                input_patch = np.random.normal(
                    loc=self.cfg.data_config.mean,
                    scale=self.cfg.data_config.std,
                    size=self.cfg.data_config.patch_size,
                ).astype(np.float32)[
                    np.newaxis, np.newaxis, ...
                ]  # add S & C dimensions
        else:
            input_patch = input_array

        # if there is a batch dimension
        if input_patch.shape[0] > 1:
            input_patch = input_patch[0:1, ...]  # keep singleton dim

        # axes need to be reformatted for the export because reshaping was done in
        # the datamodule
        if "Z" in self.cfg.data_config.axes:
            axes = "SCZYX"
        else:
            axes = "SCYX"

        # predict output, remove extra dimensions for the purpose of the prediction
        output_patch = self.predict(
            input_patch,
            data_type=SupportedData.ARRAY.value,
            axes=axes,
            tta_transforms=False,
        )

        if not isinstance(output_patch, np.ndarray):
            raise ValueError(
                f"Numpy array required for export to BioImage Model Zoo, got "
                f"{type(output_patch)}."
            )

        export_to_bmz(
            model=self.model,
            config=self.cfg,
            path=path,
            name=name,
            general_description=general_description,
            authors=authors,
            input_array=input_patch,
            output_array=output_patch,
            channel_names=channel_names,
            data_description=data_description,
        )