careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +7 -1
- careamics/bioimage/__init__.py +15 -0
- careamics/bioimage/docs/Noise2Void.md +5 -0
- careamics/bioimage/docs/__init__.py +1 -0
- careamics/bioimage/io.py +182 -0
- careamics/bioimage/rdf.py +105 -0
- careamics/config/__init__.py +11 -0
- careamics/config/algorithm.py +231 -0
- careamics/config/config.py +297 -0
- careamics/config/config_filter.py +44 -0
- careamics/config/data.py +194 -0
- careamics/config/torch_optim.py +118 -0
- careamics/config/training.py +534 -0
- careamics/dataset/__init__.py +1 -0
- careamics/dataset/dataset_utils.py +111 -0
- careamics/dataset/extraction_strategy.py +21 -0
- careamics/dataset/in_memory_dataset.py +202 -0
- careamics/dataset/patching.py +492 -0
- careamics/dataset/prepare_dataset.py +175 -0
- careamics/dataset/tiff_dataset.py +212 -0
- careamics/engine.py +1014 -0
- careamics/losses/__init__.py +4 -0
- careamics/losses/loss_factory.py +38 -0
- careamics/losses/losses.py +34 -0
- careamics/manipulation/__init__.py +4 -0
- careamics/manipulation/pixel_manipulation.py +158 -0
- careamics/models/__init__.py +4 -0
- careamics/models/layers.py +152 -0
- careamics/models/model_factory.py +251 -0
- careamics/models/unet.py +322 -0
- careamics/prediction/__init__.py +9 -0
- careamics/prediction/prediction_utils.py +106 -0
- careamics/utils/__init__.py +20 -0
- careamics/utils/ascii_logo.txt +9 -0
- careamics/utils/augment.py +65 -0
- careamics/utils/context.py +45 -0
- careamics/utils/logging.py +321 -0
- careamics/utils/metrics.py +160 -0
- careamics/utils/normalization.py +55 -0
- careamics/utils/torch_utils.py +89 -0
- careamics/utils/validators.py +170 -0
- careamics/utils/wandb.py +121 -0
- careamics-0.1.0rc2.dist-info/METADATA +81 -0
- careamics-0.1.0rc2.dist-info/RECORD +47 -0
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
careamics/engine.py
ADDED
|
@@ -0,0 +1,1014 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Engine module.
|
|
3
|
+
|
|
4
|
+
This module contains the main CAREamics class, the Engine. The Engine allows training
|
|
5
|
+
a model and using it for prediction.
|
|
6
|
+
"""
|
|
7
|
+
from logging import FileHandler
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
14
|
+
|
|
15
|
+
from .bioimage import (
|
|
16
|
+
get_default_model_specs,
|
|
17
|
+
save_bioimage_model,
|
|
18
|
+
)
|
|
19
|
+
from .config import Configuration, load_configuration
|
|
20
|
+
from .dataset.prepare_dataset import (
|
|
21
|
+
get_prediction_dataset,
|
|
22
|
+
get_train_dataset,
|
|
23
|
+
get_validation_dataset,
|
|
24
|
+
)
|
|
25
|
+
from .losses import create_loss_function
|
|
26
|
+
from .models import create_model
|
|
27
|
+
from .prediction import (
|
|
28
|
+
stitch_prediction,
|
|
29
|
+
tta_backward,
|
|
30
|
+
tta_forward,
|
|
31
|
+
)
|
|
32
|
+
from .utils import (
|
|
33
|
+
MetricTracker,
|
|
34
|
+
add_axes,
|
|
35
|
+
denormalize,
|
|
36
|
+
get_device,
|
|
37
|
+
normalize,
|
|
38
|
+
)
|
|
39
|
+
from .utils.logging import ProgressBar, get_logger
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Engine:
|
|
43
|
+
"""
|
|
44
|
+
Class allowing training of a model and subsequent prediction.
|
|
45
|
+
|
|
46
|
+
There are three ways to instantiate an Engine:
|
|
47
|
+
1. With a CAREamics model (.pth), by passing a path.
|
|
48
|
+
2. With a configuration object.
|
|
49
|
+
3. With a configuration file, by passing a path.
|
|
50
|
+
|
|
51
|
+
In each case, the parameter name must be provided explicitly. For example:
|
|
52
|
+
>>> engine = Engine(config_path="path/to/config.yaml")
|
|
53
|
+
|
|
54
|
+
Note that only one of these options can be used at a time, in the order listed
|
|
55
|
+
above.
|
|
56
|
+
|
|
57
|
+
Parameters
|
|
58
|
+
----------
|
|
59
|
+
config : Optional[Configuration], optional
|
|
60
|
+
Configuration object, by default None.
|
|
61
|
+
config_path : Optional[Union[str, Path]], optional
|
|
62
|
+
Path to configuration file, by default None.
|
|
63
|
+
model_path : Optional[Union[str, Path]], optional
|
|
64
|
+
Path to model file, by default None.
|
|
65
|
+
seed : int, optional
|
|
66
|
+
Seed for reproducibility, by default 42.
|
|
67
|
+
|
|
68
|
+
Attributes
|
|
69
|
+
----------
|
|
70
|
+
cfg : Configuration
|
|
71
|
+
Configuration.
|
|
72
|
+
device : torch.device
|
|
73
|
+
Device (CPU or GPU).
|
|
74
|
+
model : torch.nn.Module
|
|
75
|
+
Model.
|
|
76
|
+
optimizer : torch.optim.Optimizer
|
|
77
|
+
Optimizer.
|
|
78
|
+
lr_scheduler : torch.optim.lr_scheduler._LRScheduler
|
|
79
|
+
Learning rate scheduler.
|
|
80
|
+
scaler : torch.cuda.amp.GradScaler
|
|
81
|
+
Gradient scaler.
|
|
82
|
+
loss_func : Callable
|
|
83
|
+
Loss function.
|
|
84
|
+
logger : logging.Logger
|
|
85
|
+
Logger.
|
|
86
|
+
use_wandb : bool
|
|
87
|
+
Whether to use wandb.
|
|
88
|
+
"""
|
|
89
|
+
|
|
90
|
+
def __init__(
    self,
    *,
    config: Optional[Configuration] = None,
    config_path: Optional[Union[str, Path]] = None,
    model_path: Optional[Union[str, Path]] = None,
    seed: Optional[int] = 42,
) -> None:
    """
    Constructor.

    Exactly one source is used, checked in this order: `model_path`, then
    `config`, then `config_path`.

    If `seed` is not None, the Python, NumPy and PyTorch random number
    generators are seeded before the model is created. To disable seeding,
    set it to None.

    Parameters
    ----------
    config : Optional[Configuration], optional
        Configuration object, by default None.
    config_path : Optional[Union[str, Path]], optional
        Path to configuration file, by default None.
    model_path : Optional[Union[str, Path]], optional
        Path to model file, by default None.
    seed : Optional[int], optional
        Seed for reproducibility, by default 42. None disables seeding.

    Raises
    ------
    ValueError
        If all three of `config`, `config_path` and `model_path` are None.
    FileNotFoundError
        If the model path is provided but does not exist. Errors raised while
        reading `config_path` come from `load_configuration`.
    TypeError
        If the configuration is not a Configuration object.

    Notes
    -----
    If wandb logging is requested but wandb is not installed
    (ModuleNotFoundError) or not correctly configured (wandb UsageError), a
    warning is logged and wandb logging is disabled; neither error propagates.
    """
    if model_path is not None:
        if not Path(model_path).exists():
            raise FileNotFoundError(
                f"Model path {model_path} is incorrect or"
                f" does not exist. Current working directory is: {Path.cwd()!s}"
            )

        # Ensure that config is None; `create_model` receives the model path
        # below and returns the configuration.
        self.cfg = None

    elif config is not None:
        # Check that config is a Configuration object
        if not isinstance(config, Configuration):
            raise TypeError(
                f"config must be a Configuration object, got {type(config)}"
            )
        self.cfg = config
    elif config_path is not None:
        self.cfg = load_configuration(config_path)
    else:
        raise ValueError(
            "No configuration or path provided. One of configuration "
            "object, configuration path or model path must be provided."
        )

    # Seed all random number generators before the model is created so that
    # weight initialization (and subsequent training) is reproducible.
    if seed is not None:
        import random

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # get device, CPU or GPU
    self.device = get_device()

    # Create model, optimizer, lr scheduler and gradient scaler and load everything
    # to the specified device
    (
        self.model,
        self.optimizer,
        self.lr_scheduler,
        self.scaler,
        self.cfg,
    ) = create_model(config=self.cfg, model_path=model_path, device=self.device)
    # narrows the Optional type for static checkers
    assert self.cfg is not None

    # create loss function
    self.loss_func = create_loss_function(self.cfg)

    # Set logging
    log_path = self.cfg.working_directory / "log.txt"
    self.logger = get_logger(__name__, log_path=log_path)

    # wandb
    self.use_wandb = self.cfg.training.use_wandb

    if self.use_wandb:
        try:
            from wandb.errors import UsageError

            from careamics.utils.wandb import WandBLogging

            try:
                self.wandb = WandBLogging(
                    experiment_name=self.cfg.experiment_name,
                    log_path=self.cfg.working_directory,
                    config=self.cfg,
                    model_to_watch=self.model,
                )
            except UsageError as e:
                self.logger.warning(
                    f"Wandb usage error, using default logger. Check whether "
                    f"wandb correctly configured:\n"
                    f"{e}"
                )
                self.use_wandb = False

        except ModuleNotFoundError:
            self.logger.warning(
                "Wandb not installed, using default logger. Try pip install "
                "wandb"
            )
            self.use_wandb = False

    # BMZ inputs/outputs placeholders, filled during validation
    self._input = None
    self._outputs = None

    # torch version
    self.torch_version = torch.__version__
|
|
212
|
+
|
|
213
|
+
def train(
    self,
    train_path: str,
    val_path: str,
) -> Tuple[List[Any], List[Any]]:
    """
    Train the network.

    The training and validation data given by the paths must be compatible with the
    axes and data format provided in the configuration. If the configuration has
    no mean/std yet, they are taken from the training dataset.

    After each epoch the model is validated, the learning rate scheduler is
    stepped and a checkpoint is saved. A KeyboardInterrupt stops training
    early; the statistics gathered so far are still returned.

    Parameters
    ----------
    train_path : Union[str, Path]
        Path to the training data.
    val_path : Union[str, Path]
        Path to the validation data.

    Returns
    -------
    Tuple[List[Any], List[Any]]
        Tuple of training and validation statistics.

    Raises
    ------
    ValueError
        Raise a ValueError if the configuration is missing.
    """
    if self.cfg is None:
        raise ValueError("Configuration is not defined, cannot train.")

    train_loader = self._get_train_dataloader(train_path)

    # Set mean and std from train dataset if none
    if self.cfg.data.mean is None or self.cfg.data.std is None:
        self.cfg.data.set_mean_and_std(
            train_loader.dataset.mean, train_loader.dataset.std
        )

    eval_loader = self._get_val_dataloader(val_path)
    self.logger.info(f"Starting training for {self.cfg.training.num_epochs} epochs")

    val_losses = []
    train_stats = []
    eval_stats = []

    # ReduceLROnPlateau must be stepped with the monitored metric; every other
    # scheduler is stepped without arguments (its positional argument is an
    # epoch index, not a metric).
    step_with_metric = isinstance(
        self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    )

    try:
        # loop over the dataset multiple times
        for epoch in range(self.cfg.training.num_epochs):
            # iterable datasets may not define a length
            if hasattr(train_loader.dataset, "__len__"):
                epoch_size = len(train_loader)
            else:
                epoch_size = None

            progress_bar = ProgressBar(
                max_value=epoch_size,
                epoch=epoch,
                num_epochs=self.cfg.training.num_epochs,
                mode="train",
            )

            # Perform training step
            train_outputs, _ = self._train_single_epoch(
                train_loader,
                progress_bar,
                self.cfg.training.amp.use,
            )
            # Perform validation step
            eval_outputs = self._evaluate(eval_loader)
            val_losses.append(eval_outputs["loss"])
            learning_rate = self.optimizer.param_groups[0]["lr"]

            progress_bar.add(
                1,
                values=[
                    ("train_loss", train_outputs["loss"]),
                    ("val loss", eval_outputs["loss"]),
                    ("lr", learning_rate),
                ],
            )

            if step_with_metric:
                self.lr_scheduler.step(eval_outputs["loss"])
            else:
                self.lr_scheduler.step()

            if self.use_wandb:
                metrics = {
                    "train": train_outputs,
                    "eval": eval_outputs,
                    "lr": learning_rate,
                }
                self.wandb.log_metrics(metrics)

            train_stats.append(train_outputs)
            eval_stats.append(eval_outputs)

            checkpoint_path = self._save_checkpoint(epoch, val_losses, "state_dict")
            self.logger.info(f"Saved checkpoint to {checkpoint_path}")

    except KeyboardInterrupt:
        self.logger.info("Training interrupted")

    return train_stats, eval_stats
|
|
316
|
+
|
|
317
|
+
def _train_single_epoch(
    self,
    loader: torch.utils.data.DataLoader,
    progress_bar: ProgressBar,
    amp: bool,
) -> Tuple[Dict[str, float], int]:
    """
    Train for a single epoch.

    Gradients are scaled by `self.scaler` during backpropagation. The optimizer
    is stepped through the scaler, so gradients are unscaled before the update
    and the scale factor is adjusted after each iteration.

    Parameters
    ----------
    loader : torch.utils.data.DataLoader
        Training dataloader.
    progress_bar : ProgressBar
        Progress bar.
    amp : bool
        Whether to use automatic mixed precision.

    Returns
    -------
    Tuple[Dict[str, float], int]
        Tuple of training metrics (average loss under the key "loss") and
        number of batches processed.

    Raises
    ------
    ValueError
        If the configuration is missing.
    """
    if self.cfg is None:
        raise ValueError("Configuration is not defined, cannot train.")

    avg_loss = MetricTracker()
    self.model.train()
    epoch_size = 0

    for i, (batch, *auxillary) in enumerate(loader):
        self.optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=amp):
            outputs = self.model(batch.to(self.device))

        loss = self.loss_func(
            outputs, *[a.to(self.device) for a in auxillary], self.device
        )
        self.scaler.scale(loss).backward()
        avg_loss.update(loss.detach(), batch.shape[0])

        progress_bar.update(
            current_step=i,
            batch_size=self.cfg.training.batch_size,
        )

        # Step through the scaler rather than the optimizer directly: it
        # unscales the gradients first (calling optimizer.step() here would
        # apply gradients still multiplied by the loss scale), then update()
        # adjusts the scale for the next iteration.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        epoch_size += 1

    return {"loss": avg_loss.avg.to(torch.float16).cpu().numpy()}, epoch_size
|
|
373
|
+
|
|
374
|
+
def _evaluate(self, val_loader: torch.utils.data.DataLoader) -> Dict[str, float]:
    """
    Run the model on the validation set and average the loss.

    Side effect: while `self._input` is still None, the first validation batch
    is copied to a NumPy array and stored there (used later as sample input
    for the bioimage.io export).

    Parameters
    ----------
    val_loader : torch.utils.data.DataLoader
        Validation dataloader.

    Returns
    -------
    Dict[str, float]
        Dictionary holding the average validation loss under the key "loss".
    """
    self.model.eval()
    tracker = MetricTracker()

    with torch.no_grad():
        for patch, *extra in val_loader:
            # keep a copy of one batch, dimensions SC(Z)YX
            if self._input is None:
                self._input = patch.clone().detach().cpu().numpy()

            prediction = self.model(patch.to(self.device))
            extra_on_device = [tensor.to(self.device) for tensor in extra]
            batch_loss = self.loss_func(prediction, *extra_on_device, self.device)
            tracker.update(batch_loss.detach(), patch.shape[0])

        mean_loss = tracker.avg.to(torch.float16).cpu().numpy()

    return {"loss": mean_loss}
|
|
405
|
+
|
|
406
|
+
def predict(
    self,
    input: Union[np.ndarray, str, Path],
    *,
    tile_shape: Optional[List[int]] = None,
    overlaps: Optional[List[int]] = None,
    axes: Optional[str] = None,
    tta: bool = True,
) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Predict using the current model on an input array or a path to data.

    The Engine must have previously been trained and mean/std be specified in
    its configuration.

    Data should be compatible with the axes, either from the configuration or
    as passed using the `axes` parameter. If the batch and channel dimensions are
    missing, then singleton dimensions are added.

    To use tiling, both `tile_shape` and `overlaps` must be specified, have same
    length, be divisible by 2 and greater than 0. Finally, the overlaps must be
    smaller than the tiles.

    By setting `tta` to `True`, the prediction is performed using test time
    augmentation, meaning that the input is augmented and the prediction is averaged
    over the augmentations.

    Parameters
    ----------
    input : Union[np.ndarray, str, Path]
        Input data, either an array or a path to the data.
    tile_shape : Optional[List[int]], optional
        2D or 3D shape of the tiles to be predicted, by default None.
    overlaps : Optional[List[int]], optional
        2D or 3D overlaps between tiles, by default None.
    axes : Optional[str], optional
        Axes of the input array if different from the one in the configuration, by
        default None.
    tta : bool, optional
        Whether to use test time augmentation, by default True.

    Returns
    -------
    Union[np.ndarray, List[np.ndarray]]
        Predicted image array of the same shape as the input, or list of arrays
        if the arrays have inconsistent shapes.

    Raises
    ------
    ValueError
        If the configuration is missing.
    ValueError
        If the mean or std are not specified in the configuration (untrained model).
    NotImplementedError
        If tiling is requested for an in-memory array.
    """
    if self.cfg is None:
        raise ValueError("Configuration is not defined, cannot predict.")

    # Check that the mean and std are there (= has been trained). Compare to None
    # explicitly: a mean of 0.0 is a valid value but is falsy.
    if self.cfg.data.mean is None or self.cfg.data.std is None:
        raise ValueError(
            "Mean or std are not specified in the configuration, prediction cannot "
            "be performed."
        )

    # set model to eval mode
    self.model.to(self.device)
    self.model.eval()

    progress_bar = ProgressBar(num_epochs=1, mode="predict")

    # Get dataloader
    pred_loader, tiled = self._get_predict_dataloader(
        input=input, tile_shape=tile_shape, overlaps=overlaps, axes=axes
    )

    # Start prediction
    self.logger.info("Starting prediction")
    if tiled:
        self.logger.info("Starting tiled prediction")
        prediction = self._predict_tiled(pred_loader, progress_bar, tta)
    else:
        self.logger.info("Starting prediction on whole sample")
        prediction = self._predict_full(pred_loader, progress_bar, tta)

    return prediction
|
|
491
|
+
|
|
492
|
+
def _predict_tiled(
    self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True
) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Predict using tiling.

    Each loader item is one tile plus auxiliary data: a last-tile indicator
    followed by the data needed to stitch the tiles. When the last tile of a
    sample is reached, the tiles are stitched, denormalized with the
    configuration mean/std and appended to the output.

    Parameters
    ----------
    pred_loader : DataLoader
        Prediction dataloader.
    progress_bar : ProgressBar
        Progress bar.
    tta : bool, optional
        Whether to use test time augmentation, by default True.

    Returns
    -------
    Union[np.ndarray, List[np.ndarray]]
        Predicted image, or list of predictions if the images have different sizes.

    Raises
    ------
    ValueError
        If the configuration is missing, or if mean or std are not specified.

    Warns
    -----
    UserWarning
        If the samples have different shapes, the prediction then returns a list.
    """
    # checks are done here to satisfy mypy
    # check that configuration exists
    if self.cfg is None:
        raise ValueError("Configuration is not defined, cannot predict.")

    # Check that the mean and std are there (= has been trained). Compare to None
    # explicitly: a mean of 0.0 is a valid value but is falsy.
    if self.cfg.data.mean is None or self.cfg.data.std is None:
        raise ValueError(
            "Mean or std are not specified in the configuration, prediction cannot "
            "be performed."
        )

    prediction = []
    tiles = []
    stitching_data = []
    # One loader item is one tile, with or without TTA: the TTA augmentations
    # are generated and merged within a single iteration.
    n_tiles = 0

    with torch.no_grad():
        for i, (tile, *auxillary) in enumerate(pred_loader):
            # Unpack auxillary data into last tile indicator and data, required to
            # stitch tiles together
            if auxillary:
                last_tile, *data = auxillary

            if tta:
                predicted_augments = [
                    self.model(augmented_tile.to(self.device)).cpu()
                    for augmented_tile in tta_forward(tile)
                ]
                tiles.append(tta_backward(predicted_augments).squeeze())
            else:
                tiles.append(
                    self.model(tile.to(self.device)).squeeze().cpu().numpy()
                )

            stitching_data.append(data)

            if last_tile:
                # Stitch tiles together if sample is finished
                predicted_sample = stitch_prediction(tiles, stitching_data)
                predicted_sample = denormalize(
                    predicted_sample,
                    float(self.cfg.data.mean),
                    float(self.cfg.data.std),
                )
                prediction.append(predicted_sample)
                tiles.clear()
                stitching_data.clear()

            progress_bar.update(i, 1)
            n_tiles += 1

    self.logger.info(f"Predicted {len(prediction)} samples, {n_tiles} tiles in total")
    try:
        return np.stack(prediction)
    except ValueError:
        self.logger.warning("Samples have different shapes, returning list.")
        return prediction
|
|
575
|
+
|
|
576
|
+
def _predict_full(
    self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True
) -> np.ndarray:
    """
    Predict whole image without tiling.

    Every loader item is predicted in one pass (optionally with test time
    augmentation). The predictions are stacked, squeezed and denormalized with
    the configuration mean/std.

    Parameters
    ----------
    pred_loader : DataLoader
        Prediction dataloader.
    progress_bar : ProgressBar
        Progress bar.
    tta : bool, optional
        Whether to use test time augmentation, by default True.

    Returns
    -------
    np.ndarray
        Predicted image.

    Raises
    ------
    ValueError
        If the configuration is missing, or if mean or std are not specified.
    """
    # checks are done here to satisfy mypy
    # check that configuration exists
    if self.cfg is None:
        raise ValueError("Configuration is not defined, cannot predict.")

    # Check that the mean and std are there (= has been trained). Compare to None
    # explicitly: a mean of 0.0 is a valid value but is falsy.
    if self.cfg.data.mean is None or self.cfg.data.std is None:
        raise ValueError(
            "Mean or std are not specified in the configuration, prediction cannot "
            "be performed."
        )

    prediction = []
    with torch.no_grad():
        for i, sample in enumerate(pred_loader):
            if tta:
                predicted_augments = [
                    self.model(augmented_input.to(self.device)).cpu()
                    for augmented_input in tta_forward(sample[0])
                ]
                prediction.append(tta_backward(predicted_augments).squeeze())
            else:
                prediction.append(
                    self.model(sample[0].to(self.device)).squeeze().cpu().numpy()
                )
            progress_bar.update(i, 1)

    output = denormalize(
        np.stack(prediction).squeeze(),
        float(self.cfg.data.mean),
        float(self.cfg.data.std),
    )
    return output
|
|
629
|
+
|
|
630
|
+
def _get_train_dataloader(self, train_path: str) -> DataLoader:
    """
    Build the dataloader over the training data.

    The dataset comes from `get_train_dataset`; batch size and number of
    workers are read from the training section of the configuration, and
    memory pinning is enabled.

    Parameters
    ----------
    train_path : str
        Path to the training data.

    Returns
    -------
    DataLoader
        Training data loader.

    Raises
    ------
    ValueError
        If the configuration is None.
    """
    if self.cfg is None:
        raise ValueError("Configuration is not defined.")

    train_dataset = get_train_dataset(self.cfg, train_path)
    loader_options = {
        "batch_size": self.cfg.training.batch_size,
        "num_workers": self.cfg.training.num_workers,
        "pin_memory": True,
    }
    return DataLoader(train_dataset, **loader_options)
|
|
660
|
+
|
|
661
|
+
def _get_val_dataloader(self, val_path: str) -> DataLoader:
    """
    Build the dataloader over the validation data.

    The dataset comes from `get_validation_dataset` and is batched with the
    training batch size and number of workers, with pinned memory.

    Parameters
    ----------
    val_path : str
        Path to the validation data.

    Returns
    -------
    DataLoader
        Validation data loader.

    Raises
    ------
    ValueError
        If the configuration is None.
    """
    config = self.cfg
    if config is None:
        raise ValueError("Configuration is not defined.")

    val_dataset = get_validation_dataset(config, val_path)
    return DataLoader(
        val_dataset,
        batch_size=config.training.batch_size,
        num_workers=config.training.num_workers,
        pin_memory=True,
    )
|
|
691
|
+
|
|
692
|
+
def _get_predict_dataloader(
    self,
    input: Union[np.ndarray, str, Path],
    *,
    tile_shape: Optional[List[int]] = None,
    overlaps: Optional[List[int]] = None,
    axes: Optional[str] = None,
) -> Tuple[DataLoader, bool]:
    """
    Return a prediction dataloader.

    In-memory arrays are expanded to the configured (or given) axes, normalized
    with the configuration mean/std, cast to float32 and wrapped in a
    TensorDataset. Paths are handed to `get_prediction_dataset`. The loader
    always uses a batch size of 1.

    Parameters
    ----------
    input : Union[np.ndarray, str, Path]
        Input array or path to data.
    tile_shape : Optional[List[int]], optional
        2D or 3D shape of the tiles, by default None.
    overlaps : Optional[List[int]], optional
        2D or 3D overlaps between tiles, by default None.
    axes : Optional[str], optional
        Axes of the input array if different from the one in the configuration.

    Returns
    -------
    Tuple[DataLoader, bool]
        Tuple of prediction data loader, and whether the data is tiled.

    Raises
    ------
    ValueError
        If the configuration is None.
    ValueError
        If the mean or std are not specified in the configuration.
    ValueError
        If the input is None.
    TypeError
        If the input is neither a numpy array, a string nor a Path.
    NotImplementedError
        If tiling is requested for an in-memory array.
    """
    if self.cfg is None:
        raise ValueError("Configuration is not defined.")

    if self.cfg.data.mean is None or self.cfg.data.std is None:
        raise ValueError(
            "Mean or std are not specified in the configuration, prediction cannot "
            "be performed. Was the model trained?"
        )

    if input is None:
        raise ValueError("Input cannot be None.")

    if isinstance(input, np.ndarray):
        # Validate axes and add missing dimensions (S)C if necessary
        img_axes = self.cfg.data.axes if axes is None else axes
        input_expanded = add_axes(input, img_axes)

        # Check if tiling requested
        tiled = tile_shape is not None and overlaps is not None

        if tiled:
            raise NotImplementedError(
                "Tiling with in memory array is currently not implemented."
            )

        # Normalize input and cast to float32
        normalized_input = normalize(
            img=input_expanded, mean=self.cfg.data.mean, std=self.cfg.data.std
        )
        normalized_input = normalized_input.astype(np.float32)

        dataset = TensorDataset(torch.from_numpy(normalized_input))

    elif isinstance(input, (str, Path)):
        dataset = get_prediction_dataset(
            self.cfg,
            pred_path=input,
            tile_shape=tile_shape,
            overlaps=overlaps,
            axes=axes,
        )

        # the dataset only carries a patch extraction method when tiling
        tiled = (
            hasattr(dataset, "patch_extraction_method")
            and dataset.patch_extraction_method is not None
        )
    else:
        # Without this branch, `dataset` and `tiled` would be unbound below.
        raise TypeError(
            f"Input must be a numpy array, a string or a Path, got {type(input)}."
        )

    return (
        DataLoader(
            dataset,
            batch_size=1,
            num_workers=0,
            pin_memory=True,
        ),
        tiled,
    )
|
|
787
|
+
|
|
788
|
+
def _save_checkpoint(
    self, epoch: int, losses: List[float], save_method: str
) -> Path:
    """
    Write a training checkpoint into the working directory.

    The first epoch, and any epoch whose loss equals the minimum loss so far,
    writes `<experiment_name>_best.pth`; every other epoch writes
    `<experiment_name>_latest.pth`. The working directory is created if
    needed. Only `save_method="state_dict"` is supported; the checkpoint then
    holds the model, optimizer, scheduler and gradient scaler state dicts, the
    epoch, the last loss and the dumped configuration.

    Parameters
    ----------
    epoch : int
        Last epoch.
    losses : List[float]
        List of losses, the last one being the current epoch's.
    save_method : str
        Method to save the model. Currently only supports `state_dict`.

    Returns
    -------
    Path
        Absolute path to the saved checkpoint.

    Raises
    ------
    ValueError
        If the configuration is None.
    NotImplementedError
        If the requested save method is not supported.
    """
    if self.cfg is None:
        raise ValueError("Configuration is not defined.")

    is_best = epoch == 0 or losses[-1] == min(losses)
    name = (
        f"{self.cfg.experiment_name}_best.pth"
        if is_best
        else f"{self.cfg.experiment_name}_latest.pth"
    )

    workdir = self.cfg.working_directory
    workdir.mkdir(parents=True, exist_ok=True)

    if save_method != "state_dict":
        raise NotImplementedError("Invalid save method.")

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.lr_scheduler.state_dict(),
            "grad_scaler_state_dict": self.scaler.state_dict(),
            "loss": losses[-1],
            "config": self.cfg.model_dump(),
        },
        workdir / name,
    )

    return self.cfg.working_directory.absolute() / name
|
|
842
|
+
|
|
843
|
+
def __del__(self) -> None:
    """
    Exit logger.

    If the instance has a `logger` attribute, detach every `FileHandler` from it
    and close that handler so the underlying log file is released. Other
    handler types are left attached.
    """
    if hasattr(self, "logger"):
        # Iterate over a snapshot: `removeHandler` mutates `self.logger.handlers`,
        # and removing items from a list while iterating it skips the element
        # that follows each removed one.
        for handler in list(self.logger.handlers):
            if isinstance(handler, FileHandler):
                self.logger.removeHandler(handler)
                handler.close()
|
|
850
|
+
|
|
851
|
+
def _get_sample_io_files(
    self,
    input_array: Optional[np.ndarray] = None,
    axes: Optional[str] = None,
) -> Tuple[List[str], List[str]]:
    """
    Create numpy format for use as inputs and outputs in the bioimage.io archive.

    The input sample (`input_array` if given, otherwise the stored `self._input`)
    is run through `self.predict` without TTA. Input and output are then saved as
    `test_inputs.npy` and `test_outputs.npy` in the configured working directory,
    which is created if needed.

    Parameters
    ----------
    input_array : Optional[np.ndarray], optional
        Input array to use for the bioimage.io model zoo, by default None.
    axes : Optional[str], optional
        Axes from the configuration, used to add the axes missing from
        `input_array` before it is saved.

    Returns
    -------
    Tuple[List[str], List[str]]
        Tuple of one-element lists holding the absolute paths of the saved
        input and output files.

    Raises
    ------
    ValueError
        If the configuration is not defined.
    """
    if self.cfg is None or self._input is None:
        raise ValueError("Configuration is not defined or model was not trained.")

    # use the input array if provided, otherwise use the first validation sample
    if input_array is not None:
        # add axes to be compatible with the axes declared in the RDF specs.
        # The result must be kept: previously it was discarded, so the saved
        # input never received the extra axes.
        # NOTE(review): assumes `add_axes` returns the expanded array — confirm.
        array_in = add_axes(input_array, axes)

        # predict on the array as provided by the caller, as before
        array_out = self.predict(input_array, tta=False)
    else:
        array_in = self._input

        # predict (no tta since BMZ does not apply it)
        array_out = self.predict(array_in, tta=False)

    # add singleton dimensions (for compatibility with model axes)
    # indeed, BMZ applies the model but CAREamics function are meant
    # to work on user data (potentially with no S or C axe)
    array_out = array_out[np.newaxis, np.newaxis, ...]

    # save numpy files, making sure the working directory exists
    workdir = self.cfg.working_directory
    workdir.mkdir(parents=True, exist_ok=True)

    in_file = workdir.joinpath("test_inputs.npy")
    np.save(in_file, array_in)
    out_file = workdir.joinpath("test_outputs.npy")
    np.save(out_file, array_out)

    return [str(in_file.absolute())], [str(out_file.absolute())]
|
|
904
|
+
|
|
905
|
+
def _generate_rdf(
    self,
    *,
    model_specs: Optional[dict] = None,
    input_array: Optional[np.ndarray] = None,
) -> dict:
    """
    Generate rdf data for bioimage.io format export.

    Default Noise2Void specs are built from the configuration's mean, std and
    3D flag, then updated with `model_specs` if given. The sample input/output
    files and the RDF axes are always set last, so they take precedence over
    any values supplied in `model_specs`.

    Parameters
    ----------
    model_specs : Optional[dict], optional
        Custom specs if different than the default ones, by default None.
    input_array : Optional[np.ndarray], optional
        Input array to use for the bioimage.io model zoo, by default None.

    Returns
    -------
    dict
        RDF specs.

    Raises
    ------
    ValueError
        If the mean or std are not specified in the configuration.
    ValueError
        If the configuration is not defined.
    """
    config = self.cfg
    if config is None:
        raise ValueError("Configuration is not defined or model was not trained.")

    data_cfg = config.data
    if data_cfg.mean is None or data_cfg.std is None:
        raise ValueError(
            "Mean or std are not specified in the configuration, export to "
            "bioimage.io format is not possible."
        )

    # RDF axes: lowercase config axes without the sample axis, with channel and
    # then batch prepended when they are missing (yielding e.g. "bcyx")
    rdf_axes = data_cfg.axes.lower().replace("s", "")
    for leading_axis in ("c", "b"):
        if leading_axis not in rdf_axes:
            rdf_axes = leading_axis + rdf_axes

    # write the sample input/output files and collect their paths
    sample_inputs, sample_outputs = self._get_sample_io_files(
        input_array, data_cfg.axes
    )

    rdf_specs = get_default_model_specs(
        "Noise2Void", data_cfg.mean, data_cfg.std, config.algorithm.is_3D
    )
    if model_specs is not None:
        rdf_specs.update(model_specs)

    rdf_specs.update(
        test_inputs=sample_inputs,
        test_outputs=sample_outputs,
        input_axes=[rdf_axes],
        output_axes=[rdf_axes],
    )
    return rdf_specs
|
|
972
|
+
|
|
973
|
+
def save_as_bioimage(
    self,
    output_zip: Union[Path, str],
    model_specs: Optional[dict] = None,
    input_array: Optional[np.ndarray] = None,
) -> None:
    """
    Export the current model to BioImage.io model zoo format.

    Custom specs (e.g. maintainers) can be passed in `model_specs`. For a
    description of the model RDF, refer to
    github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/model_spec_latest.md.

    Parameters
    ----------
    output_zip : Union[Path, str]
        Where to save the model zip file.
    model_specs : Optional[dict]
        A dictionary with keys being the bioimage-core build_model parameters. If
        None then it will be populated by the model default specs.
    input_array : Optional[np.ndarray]
        An array to use as input for the bioimage.io model zoo. If None then the
        first validation sample will be used. Note that the array must have S and
        C dimensions (e.g. SCYX), even if only singleton dimensions.

    Raises
    ------
    ValueError
        If the configuration is not defined.
    """
    config = self.cfg
    if config is None:
        raise ValueError("Configuration is not defined.")

    # assemble the RDF (this also writes the sample input/output files)
    rdf_specs = self._generate_rdf(model_specs=model_specs, input_array=input_array)

    # package the model together with its configuration and specs
    save_bioimage_model(path=output_zip, config=config, specs=rdf_specs)
|