careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/lightning/lightning_module.py
@@ -0,0 +1,276 @@
"""CAREamics Lightning module."""

from typing import Any, Optional, Union

import pytorch_lightning as L
from torch import Tensor, nn

from careamics.config import AlgorithmConfig
from careamics.config.support import (
    SupportedAlgorithm,
    SupportedArchitecture,
    SupportedLoss,
    SupportedOptimizer,
    SupportedScheduler,
)
from careamics.losses import loss_factory
from careamics.models.model_factory import model_factory
from careamics.transforms import Denormalize, ImageRestorationTTA
from careamics.utils.torch_utils import get_optimizer, get_scheduler


class CAREamicsModule(L.LightningModule):
    """
    CAREamics Lightning module.

    This class encapsulates the PyTorch model along with the training, validation,
    and testing logic. It is configured using an `AlgorithmModel` Pydantic class.

    Parameters
    ----------
    algorithm_config : AlgorithmModel or dict
        Algorithm configuration.

    Attributes
    ----------
    model : torch.nn.Module
        PyTorch model.
    loss_func : torch.nn.Module
        Loss function.
    optimizer_name : str
        Optimizer name.
    optimizer_params : dict
        Optimizer parameters.
    lr_scheduler_name : str
        Learning rate scheduler name.
    """

    def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
        """Lightning module for CAREamics.

        This class encapsulates the PyTorch model along with the training,
        validation, and testing logic. It is configured using an `AlgorithmModel`
        Pydantic class.

        Parameters
        ----------
        algorithm_config : AlgorithmModel or dict
            Algorithm configuration.
        """
        super().__init__()
        # if loading from a checkpoint, AlgorithmModel needs to be instantiated
        if isinstance(algorithm_config, dict):
            algorithm_config = AlgorithmConfig(**algorithm_config)

        # create model and loss function
        self.model: nn.Module = model_factory(algorithm_config.model)
        self.loss_func = loss_factory(algorithm_config.loss)

        # save optimizer and lr_scheduler names and parameters
        self.optimizer_name = algorithm_config.optimizer.name
        self.optimizer_params = algorithm_config.optimizer.parameters
        self.lr_scheduler_name = algorithm_config.lr_scheduler.name
        self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters

    def forward(self, x: Any) -> Any:
        """Forward pass.

        Parameters
        ----------
        x : Any
            Input tensor.

        Returns
        -------
        Any
            Output tensor.
        """
        return self.model(x)

    def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
        """Training step.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.

        Returns
        -------
        Any
            Loss value.
        """
        # TODO can N2V be simplified by returning mask*original_patch
        x, *aux = batch
        out = self.model(x)
        loss = self.loss_func(out, *aux)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
        """Validation step.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
        """
        x, *aux = batch
        out = self.model(x)
        val_loss = self.loss_func(out, *aux)

        # log validation loss
        self.log(
            "val_loss",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
        """Prediction step.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.

        Returns
        -------
        Any
            Model output.
        """
        if self._trainer.datamodule.tiled:
            x, *aux = batch
        else:
            x = batch
            aux = []

        # apply test-time augmentation if available
        # TODO: probably won't work with batch size > 1
        if self._trainer.datamodule.prediction_config.tta_transforms:
            tta = ImageRestorationTTA()
            augmented_batch = tta.forward(x)  # list of augmented tensors
            augmented_output = []
            for augmented in augmented_batch:
                augmented_pred = self.model(augmented)
                augmented_output.append(augmented_pred)
            output = tta.backward(augmented_output)
        else:
            output = self.model(x)

        # Denormalize the output
        denorm = Denormalize(
            image_means=self._trainer.datamodule.predict_dataset.image_means,
            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
        )
        denormalized_output = denorm(patch=output.cpu().numpy())

        if len(aux) > 0:  # aux can be tiling information
            return denormalized_output, *aux
        else:
            return denormalized_output

    def configure_optimizers(self) -> Any:
        """Configure optimizers and learning rate schedulers.

        Returns
        -------
        Any
            Optimizer and learning rate scheduler.
        """
        # instantiate optimizer
        optimizer_func = get_optimizer(self.optimizer_name)
        optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)

        # and scheduler
        scheduler_func = get_scheduler(self.lr_scheduler_name)
        scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
        }


def create_careamics_module(
    algorithm: Union[SupportedAlgorithm, str],
    loss: Union[SupportedLoss, str],
    architecture: Union[SupportedArchitecture, str],
    model_parameters: Optional[dict] = None,
    optimizer: Union[SupportedOptimizer, str] = "Adam",
    optimizer_parameters: Optional[dict] = None,
    lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
    lr_scheduler_parameters: Optional[dict] = None,
) -> CAREamicsModule:
    """Create a CAREamics Lightning module.

    This function exposes parameters used to create an AlgorithmModel instance,
    triggering parameter validation.

    Parameters
    ----------
    algorithm : SupportedAlgorithm or str
        Algorithm to use for training (see SupportedAlgorithm).
    loss : SupportedLoss or str
        Loss function to use for training (see SupportedLoss).
    architecture : SupportedArchitecture or str
        Model architecture to use for training (see SupportedArchitecture).
    model_parameters : dict, optional
        Model parameters to use for training, by default {}. Model parameters are
        defined in the relevant `torch.nn.Module` class, or Pydantic model (see
        `careamics.config.architectures`).
    optimizer : SupportedOptimizer or str, optional
        Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
    optimizer_parameters : dict, optional
        Optimizer parameters to use for training, as defined in `torch.optim`, by
        default {}.
    lr_scheduler : SupportedScheduler or str, optional
        Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
        (see SupportedScheduler).
    lr_scheduler_parameters : dict, optional
        Learning rate scheduler parameters to use for training, as defined in
        `torch.optim`, by default {}.

    Returns
    -------
    CAREamicsModule
        CAREamics Lightning module.
    """
    # create an AlgorithmModel-compatible dictionary
    if lr_scheduler_parameters is None:
        lr_scheduler_parameters = {}
    if optimizer_parameters is None:
        optimizer_parameters = {}
    if model_parameters is None:
        model_parameters = {}
    algorithm_configuration = {
        "algorithm": algorithm,
        "loss": loss,
        "optimizer": {
            "name": optimizer,
            "parameters": optimizer_parameters,
        },
        "lr_scheduler": {
            "name": lr_scheduler,
            "parameters": lr_scheduler_parameters,
        },
    }
    model_configuration = {"architecture": architecture}
    model_configuration.update(model_parameters)

    # add model parameters to algorithm configuration
    algorithm_configuration["model"] = model_configuration

    # instantiate the module using a validated AlgorithmConfig instance
    return CAREamicsModule(AlgorithmConfig(**algorithm_configuration))
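For orientation, a minimal usage sketch of the factory above (not part of the released code): it builds a CAREamicsModule from explicit parameters instead of a full AlgorithmConfig. The "n2v" and "UNet" strings and the optimizer parameter are illustrative assumptions; the accepted values are defined by the Supported* enums in careamics.config.support.

# Illustrative sketch only, not part of the release.
from careamics.lightning.lightning_module import create_careamics_module

module = create_careamics_module(
    algorithm="n2v",                    # assumed SupportedAlgorithm value
    loss="n2v",                         # assumed SupportedLoss value
    architecture="UNet",                # assumed SupportedArchitecture value
    model_parameters={},                # forwarded to the architecture Pydantic model
    optimizer="Adam",
    optimizer_parameters={"lr": 1e-4},  # forwarded to torch.optim.Adam
    lr_scheduler="ReduceLROnPlateau",
)
# `module` is a CAREamicsModule and can be trained with a pytorch_lightning.Trainer.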
careamics/lightning/predict_data_module.py
@@ -0,0 +1,333 @@
"""Prediction Lightning data modules."""

from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union

import numpy as np
import pytorch_lightning as L
from numpy.typing import NDArray
from torch.utils.data import DataLoader

from careamics.config import InferenceConfig
from careamics.config.support import SupportedData
from careamics.dataset import (
    InMemoryPredDataset,
    InMemoryTiledPredDataset,
    IterablePredDataset,
    IterableTiledPredDataset,
)
from careamics.dataset.dataset_utils import list_files
from careamics.dataset.tiling.collate_tiles import collate_tiles
from careamics.file_io.read import get_read_func
from careamics.utils import get_logger

PredictDatasetType = Union[
    InMemoryPredDataset,
    InMemoryTiledPredDataset,
    IterablePredDataset,
    IterableTiledPredDataset,
]

logger = get_logger(__name__)


class PredictDataModule(L.LightningDataModule):
    """
    CAREamics Lightning prediction data module.

    The data module can be used with Path, str or numpy arrays. The data can be either
    a folder containing images or a single file.

    To read custom data types, you can set `data_type` to `custom` in `data_config`
    and provide a function that returns a numpy array from a path as the
    `read_source_func` parameter. The function will receive a Path object and
    an axes string as arguments, the axes being derived from the `data_config`.

    You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
    "*.czi") to filter the file extensions using `extension_filter`.

    Parameters
    ----------
    pred_config : InferenceModel
        Pydantic model for CAREamics prediction configuration.
    pred_data : pathlib.Path or str or numpy.ndarray
        Prediction data, can be a path to a folder, a file or a numpy array.
    read_source_func : Callable, optional
        Function to read custom types, by default None.
    extension_filter : str, optional
        Filter for file extensions of custom types, by default "".
    dataloader_params : dict, optional
        Dataloader parameters, by default {}.
    """

    def __init__(
        self,
        pred_config: InferenceConfig,
        pred_data: Union[Path, str, NDArray],
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
        dataloader_params: Optional[dict] = None,
    ) -> None:
        """
        Constructor.

        The data module can be used with Path, str or numpy arrays. The data can be
        either a folder containing images or a single file.

        To read custom data types, you can set `data_type` to `custom` in `data_config`
        and provide a function that returns a numpy array from a path as the
        `read_source_func` parameter. The function will receive a Path object and
        an axes string as arguments, the axes being derived from the `data_config`.

        You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
        "*.czi") to filter the file extensions using `extension_filter`.

        Parameters
        ----------
        pred_config : InferenceModel
            Pydantic model for CAREamics prediction configuration.
        pred_data : pathlib.Path or str or numpy.ndarray
            Prediction data, can be a path to a folder, a file or a numpy array.
        read_source_func : Callable, optional
            Function to read custom types, by default None.
        extension_filter : str, optional
            Filter for file extensions of custom types, by default "".
        dataloader_params : dict, optional
            Dataloader parameters, by default {}.

        Raises
        ------
        ValueError
            If the data type is `custom` and no `read_source_func` is provided.
        ValueError
            If the data type is `array` and the input is not a numpy array.
        ValueError
            If the data type is `tiff` and the input is neither a Path nor a str.
        """
        if dataloader_params is None:
            dataloader_params = {}
        super().__init__()

        # check that a read source function is provided for custom types
        if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
            raise ValueError(
                f"Data type {SupportedData.CUSTOM} is not allowed without "
                f"specifying a `read_source_func` and an `extension_filter`."
            )

        # check correct input type
        if (
            isinstance(pred_data, np.ndarray)
            and pred_config.data_type != SupportedData.ARRAY
        ):
            raise ValueError(
                f"Received a numpy array as input, but the data type was set to "
                f"{pred_config.data_type}. Set the data type "
                f"to {SupportedData.ARRAY} to predict on numpy arrays."
            )

        # and that Path or str are passed, if tiff file type specified
        elif (isinstance(pred_data, Path) or isinstance(pred_data, str)) and (
            pred_config.data_type != SupportedData.TIFF
            and pred_config.data_type != SupportedData.CUSTOM
        ):
            raise ValueError(
                f"Received a path as input, but the data type was neither set to "
                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
                f"to {SupportedData.TIFF} or "
                f"{SupportedData.CUSTOM} to predict on files."
            )

        # configuration data
        self.prediction_config = pred_config
        self.data_type = pred_config.data_type
        self.batch_size = pred_config.batch_size
        self.dataloader_params = dataloader_params

        self.pred_data = pred_data
        self.tile_size = pred_config.tile_size
        self.tile_overlap = pred_config.tile_overlap

        # check if it is tiled
        self.tiled = self.tile_size is not None and self.tile_overlap is not None

        # read source function
        if pred_config.data_type == SupportedData.CUSTOM:
            # mypy check
            assert read_source_func is not None

            self.read_source_func: Callable = read_source_func
        elif pred_config.data_type != SupportedData.ARRAY:
            self.read_source_func = get_read_func(pred_config.data_type)

        self.extension_filter = extension_filter

    def prepare_data(self) -> None:
        """Hook used to prepare the data before calling `setup`."""
        # if the data is a Path or a str
        if not isinstance(self.pred_data, np.ndarray):
            self.pred_files = list_files(
                self.pred_data, self.data_type, self.extension_filter
            )

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Hook called at the beginning of predict.

        Parameters
        ----------
        stage : str, optional
            Stage, by default None.
        """
        # if numpy array
        if self.data_type == SupportedData.ARRAY:
            if self.tiled:
                self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
                    prediction_config=self.prediction_config,
                    inputs=self.pred_data,
                )
            else:
                self.predict_dataset = InMemoryPredDataset(
                    prediction_config=self.prediction_config,
                    inputs=self.pred_data,
                )
        else:
            if self.tiled:
                self.predict_dataset = IterableTiledPredDataset(
                    prediction_config=self.prediction_config,
                    src_files=self.pred_files,
                    read_source_func=self.read_source_func,
                )
            else:
                self.predict_dataset = IterablePredDataset(
                    prediction_config=self.prediction_config,
                    src_files=self.pred_files,
                    read_source_func=self.read_source_func,
                )

    def predict_dataloader(self) -> DataLoader:
        """
        Create a dataloader for prediction.

        Returns
        -------
        DataLoader
            Prediction dataloader.
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            collate_fn=collate_tiles if self.tiled else None,
            **self.dataloader_params,
        )


def create_predict_datamodule(
    pred_data: Union[str, Path, NDArray],
    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
    axes: str,
    image_means: list[float],
    image_stds: list[float],
    tile_size: Optional[tuple[int, ...]] = None,
    tile_overlap: Optional[tuple[int, ...]] = None,
    batch_size: int = 1,
    tta_transforms: bool = True,
    read_source_func: Optional[Callable] = None,
    extension_filter: str = "",
    dataloader_params: Optional[dict] = None,
) -> PredictDataModule:
    """Create a CAREamics prediction Lightning datamodule.

    This function is used to explicitly pass the parameters usually contained in an
    `inference_model` configuration.

    Since the lightning datamodule has no access to the model, make sure that the
    parameters passed to the datamodule are consistent with the model's requirements
    and are coherent. This can be done by creating a `Configuration` object beforehand
    and passing its parameters to the different Lightning modules.

    The data module can be used with Path, str or numpy arrays. To use array data, set
    `data_type` to `array` and pass a numpy array to `pred_data`.

    By default, CAREamics only supports types defined in
    `careamics.config.support.SupportedData`. To read custom data types, you can set
    `data_type` to `custom` and provide a function that returns a numpy array from a
    path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
    (e.g. "*.jpeg") to filter the file extensions using `extension_filter`.

    In `dataloader_params`, you can pass any parameter accepted by PyTorch
    dataloaders, except for `batch_size`, which is set by the `batch_size`
    parameter.

    Parameters
    ----------
    pred_data : str or pathlib.Path or numpy.ndarray
        Prediction data.
    data_type : {"array", "tiff", "custom"}
        Data type, see `SupportedData` for available options.
    axes : str
        Axes of the data, chosen among SCZYX.
    image_means : list of float
        Mean values for normalization, only used if normalization is defined.
    image_stds : list of float
        Std values for normalization, only used if normalization is defined.
    tile_size : tuple of int, optional
        Tile size, 2D or 3D tile size.
    tile_overlap : tuple of int, optional
        Tile overlap, 2D or 3D tile overlap.
    batch_size : int, optional
        Batch size, by default 1.
    tta_transforms : bool, optional
        Use test-time augmentation, by default True.
    read_source_func : Callable, optional
        Function to read the source data, used if `data_type` is `custom`, by
        default None.
    extension_filter : str, optional
        Filter for file extensions, used if `data_type` is `custom`, by default "".
    dataloader_params : dict, optional
        Pytorch dataloader parameters, by default {}.

    Returns
    -------
    PredictDataModule
        CAREamics prediction datamodule.

    Notes
    -----
    If you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. If your image is smaller than the tile size along
    a dimension, as may happen along Z, consider padding your image.
    """
    if dataloader_params is None:
        dataloader_params = {}

    prediction_dict: dict[str, Any] = {
        "data_type": data_type,
        "tile_size": tile_size,
        "tile_overlap": tile_overlap,
        "axes": axes,
        "image_means": image_means,
        "image_stds": image_stds,
        "tta_transforms": tta_transforms,
        "batch_size": batch_size,
    }

    # validate configuration
    prediction_config = InferenceConfig(**prediction_dict)

    # sanity check on the dataloader parameters
    if "batch_size" in dataloader_params:
        # remove it
        del dataloader_params["batch_size"]

    return PredictDataModule(
        pred_config=prediction_config,
        pred_data=pred_data,
        read_source_func=read_source_func,
        extension_filter=extension_filter,
        dataloader_params=dataloader_params,
    )
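A companion sketch (also not part of the released code), assuming the CAREamicsModule from the previous example: it builds a tiled TIFF prediction datamodule with create_predict_datamodule and runs it through a pytorch_lightning.Trainer. The path, mean/std values and tile sizes are placeholders; the statistics should come from training, and the tile size should respect the divisibility note in the docstring above.

# Illustrative sketch only, not part of the release.
from pytorch_lightning import Trainer

from careamics.lightning.predict_data_module import create_predict_datamodule

datamodule = create_predict_datamodule(
    pred_data="path/to/tiff_folder",      # placeholder folder of TIFF files
    data_type="tiff",
    axes="YX",
    image_means=[128.0],                  # placeholders: reuse the training statistics
    image_stds=[32.0],
    tile_size=(256, 256),                 # divisible by 2**depth for a UNet
    tile_overlap=(48, 48),
    batch_size=1,
    tta_transforms=True,
    dataloader_params={"num_workers": 4},
)

# `module` is the CAREamicsModule from the previous sketch; `predict_step`
# denormalizes each tile and returns the tiling information used for stitching.
predictions = Trainer().predict(module, datamodule=datamodule)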