careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic.

Files changed (48)
  1. careamics/__init__.py +7 -1
  2. careamics/bioimage/__init__.py +15 -0
  3. careamics/bioimage/docs/Noise2Void.md +5 -0
  4. careamics/bioimage/docs/__init__.py +1 -0
  5. careamics/bioimage/io.py +182 -0
  6. careamics/bioimage/rdf.py +105 -0
  7. careamics/config/__init__.py +11 -0
  8. careamics/config/algorithm.py +231 -0
  9. careamics/config/config.py +297 -0
  10. careamics/config/config_filter.py +44 -0
  11. careamics/config/data.py +194 -0
  12. careamics/config/torch_optim.py +118 -0
  13. careamics/config/training.py +534 -0
  14. careamics/dataset/__init__.py +1 -0
  15. careamics/dataset/dataset_utils.py +111 -0
  16. careamics/dataset/extraction_strategy.py +21 -0
  17. careamics/dataset/in_memory_dataset.py +202 -0
  18. careamics/dataset/patching.py +492 -0
  19. careamics/dataset/prepare_dataset.py +175 -0
  20. careamics/dataset/tiff_dataset.py +212 -0
  21. careamics/engine.py +1014 -0
  22. careamics/losses/__init__.py +4 -0
  23. careamics/losses/loss_factory.py +38 -0
  24. careamics/losses/losses.py +34 -0
  25. careamics/manipulation/__init__.py +4 -0
  26. careamics/manipulation/pixel_manipulation.py +158 -0
  27. careamics/models/__init__.py +4 -0
  28. careamics/models/layers.py +152 -0
  29. careamics/models/model_factory.py +251 -0
  30. careamics/models/unet.py +322 -0
  31. careamics/prediction/__init__.py +9 -0
  32. careamics/prediction/prediction_utils.py +106 -0
  33. careamics/utils/__init__.py +20 -0
  34. careamics/utils/ascii_logo.txt +9 -0
  35. careamics/utils/augment.py +65 -0
  36. careamics/utils/context.py +45 -0
  37. careamics/utils/logging.py +321 -0
  38. careamics/utils/metrics.py +160 -0
  39. careamics/utils/normalization.py +55 -0
  40. careamics/utils/torch_utils.py +89 -0
  41. careamics/utils/validators.py +170 -0
  42. careamics/utils/wandb.py +121 -0
  43. careamics-0.1.0rc2.dist-info/METADATA +81 -0
  44. careamics-0.1.0rc2.dist-info/RECORD +47 -0
  45. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
  46. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
  47. careamics-0.0.1.dist-info/METADATA +0 -46
  48. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/engine.py ADDED
@@ -0,0 +1,1014 @@
+"""
+Engine module.
+
+This module contains the main CAREamics class, the Engine. The Engine allows training
+a model and using it for prediction.
+"""
+from logging import FileHandler
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+from .bioimage import (
+    get_default_model_specs,
+    save_bioimage_model,
+)
+from .config import Configuration, load_configuration
+from .dataset.prepare_dataset import (
+    get_prediction_dataset,
+    get_train_dataset,
+    get_validation_dataset,
+)
+from .losses import create_loss_function
+from .models import create_model
+from .prediction import (
+    stitch_prediction,
+    tta_backward,
+    tta_forward,
+)
+from .utils import (
+    MetricTracker,
+    add_axes,
+    denormalize,
+    get_device,
+    normalize,
+)
+from .utils.logging import ProgressBar, get_logger
+
+
+class Engine:
+    """
+    Class allowing training of a model and subsequent prediction.
+
+    There are three ways to instantiate an Engine:
+    1. With a CAREamics model (.pth), by passing a path.
+    2. With a configuration object.
+    3. With a configuration file, by passing a path.
+
+    In each case, the parameter name must be provided explicitly. For example:
+    >>> engine = Engine(config_path="path/to/config.yaml")
+
+    Note that only one of these options can be used at a time, in the order listed
+    above.
+
+    Parameters
+    ----------
+    config : Optional[Configuration], optional
+        Configuration object, by default None.
+    config_path : Optional[Union[str, Path]], optional
+        Path to configuration file, by default None.
+    model_path : Optional[Union[str, Path]], optional
+        Path to model file, by default None.
+    seed : int, optional
+        Seed for reproducibility, by default 42.
+
+    Attributes
+    ----------
+    cfg : Configuration
+        Configuration.
+    device : torch.device
+        Device (CPU or GPU).
+    model : torch.nn.Module
+        Model.
+    optimizer : torch.optim.Optimizer
+        Optimizer.
+    lr_scheduler : torch.optim.lr_scheduler._LRScheduler
+        Learning rate scheduler.
+    scaler : torch.cuda.amp.GradScaler
+        Gradient scaler.
+    loss_func : Callable
+        Loss function.
+    logger : logging.Logger
+        Logger.
+    use_wandb : bool
+        Whether to use wandb.
+    """
+
+    def __init__(
+        self,
+        *,
+        config: Optional[Configuration] = None,
+        config_path: Optional[Union[str, Path]] = None,
+        model_path: Optional[Union[str, Path]] = None,
+        seed: Optional[int] = 42,
+    ) -> None:
+        """
+        Constructor.
+
+        To disable the seed, set it to None.
+
+        Parameters
+        ----------
+        config : Optional[Configuration], optional
+            Configuration object, by default None.
+        config_path : Optional[Union[str, Path]], optional
+            Path to configuration file, by default None.
+        model_path : Optional[Union[str, Path]], optional
+            Path to model file, by default None.
+        seed : int, optional
+            Seed for reproducibility, by default 42.
+
+        Raises
+        ------
+        ValueError
+            If all three parameters are None.
+        FileNotFoundError
+            If the model or configuration path is provided but does not exist.
+        TypeError
+            If the configuration is not a Configuration object.
+        UsageError
+            If wandb is not correctly installed.
+        ModuleNotFoundError
+            If wandb is not installed.
+        ValueError
+            If the configuration failed to configure.
+        """
+        if model_path is not None:
+            if not Path(model_path).exists():
+                raise FileNotFoundError(
+                    f"Model path {model_path} is incorrect or"
+                    f" does not exist. Current working directory is: {Path.cwd()!s}"
+                )
+
+            # Ensure that config is None
+            self.cfg = None
+
+        elif config is not None:
+            # Check that config is a Configuration object
+            if not isinstance(config, Configuration):
+                raise TypeError(
+                    f"config must be a Configuration object, got {type(config)}"
+                )
+            self.cfg = config
+        elif config_path is not None:
+            self.cfg = load_configuration(config_path)
+        else:
+            raise ValueError(
+                "No configuration or path provided. One of configuration "
+                "object, configuration path or model path must be provided."
+            )
+
+        # get device, CPU or GPU
+        self.device = get_device()
+
+        # Create model, optimizer, lr scheduler and gradient scaler and load everything
+        # to the specified device
+        (
+            self.model,
+            self.optimizer,
+            self.lr_scheduler,
+            self.scaler,
+            self.cfg,
+        ) = create_model(config=self.cfg, model_path=model_path, device=self.device)
+        assert self.cfg is not None
+
+        # create loss function
+        self.loss_func = create_loss_function(self.cfg)
+
+        # Set logging
+        log_path = self.cfg.working_directory / "log.txt"
+        self.logger = get_logger(__name__, log_path=log_path)
+
+        # wandb
+        self.use_wandb = self.cfg.training.use_wandb
+
+        if self.use_wandb:
+            try:
+                from wandb.errors import UsageError
+
+                from careamics.utils.wandb import WandBLogging
+
+                try:
+                    self.wandb = WandBLogging(
+                        experiment_name=self.cfg.experiment_name,
+                        log_path=self.cfg.working_directory,
+                        config=self.cfg,
+                        model_to_watch=self.model,
+                    )
+                except UsageError as e:
+                    self.logger.warning(
+                        f"Wandb usage error, using default logger. Check whether "
+                        f"wandb correctly configured:\n"
+                        f"{e}"
+                    )
+                    self.use_wandb = False
+
+            except ModuleNotFoundError:
+                self.logger.warning(
+                    "Wandb not installed, using default logger. Try pip install "
+                    "wandb"
+                )
+                self.use_wandb = False
+
+        # BMZ inputs/outputs placeholders, filled during validation
+        self._input = None
+        self._outputs = None
+
+        # torch version
+        self.torch_version = torch.__version__
+
+    def train(
+        self,
+        train_path: str,
+        val_path: str,
+    ) -> Tuple[List[Any], List[Any]]:
+        """
+        Train the network.
+
+        The training and validation data given by the paths must be compatible with the
+        axes and data format provided in the configuration.
+
+        Parameters
+        ----------
+        train_path : Union[str, Path]
+            Path to the training data.
+        val_path : Union[str, Path]
+            Path to the validation data.
+
+        Returns
+        -------
+        Tuple[List[Any], List[Any]]
+            Tuple of training and validation statistics.
+
+        Raises
+        ------
+        ValueError
+            Raise a ValueError if the configuration is missing.
+        """
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined, cannot train.")
+
+        # General func
+        train_loader = self._get_train_dataloader(train_path)
+
+        # Set mean and std from train dataset if none
+        if self.cfg.data.mean is None or self.cfg.data.std is None:
+            self.cfg.data.set_mean_and_std(
+                train_loader.dataset.mean, train_loader.dataset.std
+            )
+
+        eval_loader = self._get_val_dataloader(val_path)
+        self.logger.info(f"Starting training for {self.cfg.training.num_epochs} epochs")
+
+        val_losses = []
+
+        try:
+            train_stats = []
+            eval_stats = []
+
+            # loop over the dataset multiple times
+            for epoch in range(self.cfg.training.num_epochs):
+                if hasattr(train_loader.dataset, "__len__"):
+                    epoch_size = train_loader.__len__()
+                else:
+                    epoch_size = None
+
+                progress_bar = ProgressBar(
+                    max_value=epoch_size,
+                    epoch=epoch,
+                    num_epochs=self.cfg.training.num_epochs,
+                    mode="train",
+                )
+                # train_epoch = train_op(self._train_single_epoch,)
+                # Perform training step
+                train_outputs, epoch_size = self._train_single_epoch(
+                    train_loader,
+                    progress_bar,
+                    self.cfg.training.amp.use,
+                )
+                # Perform validation step
+                eval_outputs = self._evaluate(eval_loader)
+                val_losses.append(eval_outputs["loss"])
+                learning_rate = self.optimizer.param_groups[0]["lr"]
+
+                progress_bar.add(
+                    1,
+                    values=[
+                        ("train_loss", train_outputs["loss"]),
+                        ("val loss", eval_outputs["loss"]),
+                        ("lr", learning_rate),
+                    ],
+                )
+                # Add update scheduler rule based on type
+                self.lr_scheduler.step(eval_outputs["loss"])
+
+                if self.use_wandb:
+                    metrics = {
+                        "train": train_outputs,
+                        "eval": eval_outputs,
+                        "lr": learning_rate,
+                    }
+                    self.wandb.log_metrics(metrics)
+
+                train_stats.append(train_outputs)
+                eval_stats.append(eval_outputs)
+
+                checkpoint_path = self._save_checkpoint(epoch, val_losses, "state_dict")
+                self.logger.info(f"Saved checkpoint to {checkpoint_path}")
+
+        except KeyboardInterrupt:
+            self.logger.info("Training interrupted")
+
+        return train_stats, eval_stats
+
+    def _train_single_epoch(
+        self,
+        loader: torch.utils.data.DataLoader,
+        progress_bar: ProgressBar,
+        amp: bool,
+    ) -> Tuple[Dict[str, float], int]:
+        """
+        Train for a single epoch.
+
+        Parameters
+        ----------
+        loader : torch.utils.data.DataLoader
+            Training dataloader.
+        progress_bar : ProgressBar
+            Progress bar.
+        amp : bool
+            Whether to use automatic mixed precision.
+
+        Returns
+        -------
+        Tuple[Dict[str, float], int]
+            Tuple of training metrics and epoch size.
+
+        Raises
+        ------
+        ValueError
+            If the configuration is missing.
+        """
+        if self.cfg is not None:
+            avg_loss = MetricTracker()
+            self.model.train()
+            epoch_size = 0
+
+            for i, (batch, *auxillary) in enumerate(loader):
+                self.optimizer.zero_grad(set_to_none=True)
+
+                with torch.cuda.amp.autocast(enabled=amp):
+                    outputs = self.model(batch.to(self.device))
+
+                loss = self.loss_func(
+                    outputs, *[a.to(self.device) for a in auxillary], self.device
+                )
+                self.scaler.scale(loss).backward()
+                avg_loss.update(loss.detach(), batch.shape[0])
+
+                progress_bar.update(
+                    current_step=i,
+                    batch_size=self.cfg.training.batch_size,
+                )
+
+                self.optimizer.step()
+                epoch_size += 1
+
+            return {"loss": avg_loss.avg.to(torch.float16).cpu().numpy()}, epoch_size
+        else:
+            raise ValueError("Configuration is not defined, cannot train.")
+
+    def _evaluate(self, val_loader: torch.utils.data.DataLoader) -> Dict[str, float]:
+        """
+        Perform validation step.
+
+        Parameters
+        ----------
+        val_loader : torch.utils.data.DataLoader
+            Validation dataloader.
+
+        Returns
+        -------
+        Dict[str, float]
+            Loss value on the validation set.
+        """
+        self.model.eval()
+        avg_loss = MetricTracker()
+
+        with torch.no_grad():
+            for patch, *auxillary in val_loader:
+                # if inputs is None, record a single patch
+                if self._input is None:
+                    # patch has dimension SC(Z)YX
+                    self._input = patch.clone().detach().cpu().numpy()
+
+                # evaluate
+                outputs = self.model(patch.to(self.device))
+                loss = self.loss_func(
+                    outputs, *[a.to(self.device) for a in auxillary], self.device
+                )
+                avg_loss.update(loss.detach(), patch.shape[0])
+        return {"loss": avg_loss.avg.to(torch.float16).cpu().numpy()}
+
+    def predict(
+        self,
+        input: Union[np.ndarray, str, Path],
+        *,
+        tile_shape: Optional[List[int]] = None,
+        overlaps: Optional[List[int]] = None,
+        axes: Optional[str] = None,
+        tta: bool = True,
+    ) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        Predict using the current model on an input array or a path to data.
+
+        The Engine must have previously been trained and mean/std be specified in
+        its configuration.
+
+        Data should be compatible with the axes, either from the configuration or
+        as passed using the `axes` parameter. If the batch and channel dimensions are
+        missing, then singleton dimensions are added.
+
+        To use tiling, both `tile_shape` and `overlaps` must be specified, have the same
+        length, be divisible by 2 and greater than 0. Finally, the overlaps must be
+        smaller than the tiles.
+
+        By setting `tta` to `True`, the prediction is performed using test time
+        augmentation, meaning that the input is augmented and the prediction is averaged
+        over the augmentations.
+
+        Parameters
+        ----------
+        input : Union[np.ndarray, str, Path]
+            Input data, either an array or a path to the data.
+        tile_shape : Optional[List[int]], optional
+            2D or 3D shape of the tiles to be predicted, by default None.
+        overlaps : Optional[List[int]], optional
+            2D or 3D overlaps between tiles, by default None.
+        axes : Optional[str], optional
+            Axes of the input array if different from the one in the configuration, by
+            default None.
+        tta : bool, optional
+            Whether to use test time augmentation, by default True.
+
+        Returns
+        -------
+        Union[np.ndarray, List[np.ndarray]]
+            Predicted image array of the same shape as the input, or list of arrays
+            if the arrays have inconsistent shapes.
+
+        Raises
+        ------
+        ValueError
+            If the configuration is missing.
+        ValueError
+            If the mean or std are not specified in the configuration (untrained model).
+        """
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined, cannot predict.")
+
+        # Check that the mean and std are there (= has been trained)
+        if not self.cfg.data.mean or not self.cfg.data.std:
+            raise ValueError(
+                "Mean or std are not specified in the configuration, prediction cannot "
+                "be performed."
+            )
+
+        # set model to eval mode
+        self.model.to(self.device)
+        self.model.eval()
+
+        progress_bar = ProgressBar(num_epochs=1, mode="predict")
+
+        # Get dataloader
+        pred_loader, tiled = self._get_predict_dataloader(
+            input=input, tile_shape=tile_shape, overlaps=overlaps, axes=axes
+        )
+
+        # Start prediction
+        self.logger.info("Starting prediction")
+        if tiled:
+            self.logger.info("Starting tiled prediction")
+            prediction = self._predict_tiled(pred_loader, progress_bar, tta)
+        else:
+            self.logger.info("Starting prediction on whole sample")
+            prediction = self._predict_full(pred_loader, progress_bar, tta)
+
+        return prediction
+
+    def _predict_tiled(
+        self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True
+    ) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        Predict using tiling.
+
+        Parameters
+        ----------
+        pred_loader : DataLoader
+            Prediction dataloader.
+        progress_bar : ProgressBar
+            Progress bar.
+        tta : bool, optional
+            Whether to use test time augmentation, by default True.
+
+        Returns
+        -------
+        Union[np.ndarray, List[np.ndarray]]
+            Predicted image, or list of predictions if the images have different sizes.
+
+        Warns
+        -----
+        UserWarning
+            If the samples have different shapes, the prediction then returns a list.
+        """
+        # checks are done here to satisfy mypy
+        # check that configuration exists
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined, cannot predict.")
+
+        # Check that the mean and std are there (= has been trained)
+        if not self.cfg.data.mean or not self.cfg.data.std:
+            raise ValueError(
+                "Mean or std are not specified in the configuration, prediction cannot "
+                "be performed."
+            )
+
+        prediction = []
+        tiles = []
+        stitching_data = []
+
+        with torch.no_grad():
+            for i, (tile, *auxillary) in enumerate(pred_loader):
+                # Unpack auxillary data into last tile indicator and data, required to
+                # stitch tiles together
+                if auxillary:
+                    last_tile, *data = auxillary
+
+                if tta:
+                    augmented_tiles = tta_forward(tile)
+                    predicted_augments = []
+                    for augmented_tile in augmented_tiles:
+                        augmented_pred = self.model(augmented_tile.to(self.device))
+                        predicted_augments.append(augmented_pred.cpu())
+                    tiles.append(tta_backward(predicted_augments).squeeze())
+                else:
+                    tiles.append(
+                        self.model(tile.to(self.device)).squeeze().cpu().numpy()
+                    )
+
+                stitching_data.append(data)
+
+                if last_tile:
+                    # Stitch tiles together if sample is finished
+                    predicted_sample = stitch_prediction(tiles, stitching_data)
+                    predicted_sample = denormalize(
+                        predicted_sample,
+                        float(self.cfg.data.mean),
+                        float(self.cfg.data.std),
+                    )
+                    prediction.append(predicted_sample)
+                    tiles.clear()
+                    stitching_data.clear()
+
+                progress_bar.update(i, 1)
+        if tta:
+            i = int(i / 8)
+        self.logger.info(f"Predicted {len(prediction)} samples, {i} tiles in total")
+        try:
+            return np.stack(prediction)
+        except ValueError:
+            self.logger.warning("Samples have different shapes, returning list.")
+            return prediction
+
+    def _predict_full(
+        self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True
+    ) -> np.ndarray:
+        """
+        Predict whole image without tiling.
+
+        Parameters
+        ----------
+        pred_loader : DataLoader
+            Prediction dataloader.
+        progress_bar : ProgressBar
+            Progress bar.
+        tta : bool, optional
+            Whether to use test time augmentation, by default True.
+
+        Returns
+        -------
+        np.ndarray
+            Predicted image.
+        """
+        # checks are done here to satisfy mypy
+        # check that configuration exists
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined, cannot predict.")
+
+        # Check that the mean and std are there (= has been trained)
+        if not self.cfg.data.mean or not self.cfg.data.std:
+            raise ValueError(
+                "Mean or std are not specified in the configuration, prediction cannot "
+                "be performed."
+            )
+
+        prediction = []
+        with torch.no_grad():
+            for i, sample in enumerate(pred_loader):
+                if tta:
+                    augmented_preds = tta_forward(sample[0])
+                    predicted_augments = []
+                    for augmented_pred in augmented_preds:
+                        augmented_pred = self.model(augmented_pred.to(self.device))
+                        predicted_augments.append(augmented_pred.cpu())
+                    prediction.append(tta_backward(predicted_augments).squeeze())
+                else:
+                    prediction.append(
+                        self.model(sample[0].to(self.device)).squeeze().cpu().numpy()
+                    )
+                progress_bar.update(i, 1)
+        output = denormalize(
+            np.stack(prediction).squeeze(),
+            float(self.cfg.data.mean),
+            float(self.cfg.data.std),
+        )
+        return output
+
+    def _get_train_dataloader(self, train_path: str) -> DataLoader:
+        """
+        Return a training dataloader.
+
+        Parameters
+        ----------
+        train_path : str
+            Path to the training data.
+
+        Returns
+        -------
+        DataLoader
+            Training data loader.
+
+        Raises
+        ------
+        ValueError
+            If the training configuration is None.
+        """
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined.")
+
+        dataset = get_train_dataset(self.cfg, train_path)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=self.cfg.training.batch_size,
+            num_workers=self.cfg.training.num_workers,
+            pin_memory=True,
+        )
+        return dataloader
+
+    def _get_val_dataloader(self, val_path: str) -> DataLoader:
+        """
+        Return a validation dataloader.
+
+        Parameters
+        ----------
+        val_path : str
+            Path to the validation data.
+
+        Returns
+        -------
+        DataLoader
+            Validation data loader.
+
+        Raises
+        ------
+        ValueError
+            If the configuration is None.
+        """
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined.")
+
+        dataset = get_validation_dataset(self.cfg, val_path)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=self.cfg.training.batch_size,
+            num_workers=self.cfg.training.num_workers,
+            pin_memory=True,
+        )
+        return dataloader
+
+    def _get_predict_dataloader(
+        self,
+        input: Union[np.ndarray, str, Path],
+        *,
+        tile_shape: Optional[List[int]] = None,
+        overlaps: Optional[List[int]] = None,
+        axes: Optional[str] = None,
+    ) -> Tuple[DataLoader, bool]:
+        """
+        Return a prediction dataloader.
+
+        Parameters
+        ----------
+        input : Union[np.ndarray, str, Path]
+            Input array or path to data.
+        tile_shape : Optional[List[int]], optional
+            2D or 3D shape of the tiles, by default None.
+        overlaps : Optional[List[int]], optional
+            2D or 3D overlaps between tiles, by default None.
+        axes : Optional[str], optional
+            Axes of the input array if different from the one in the configuration.
+
+        Returns
+        -------
+        Tuple[DataLoader, bool]
+            Tuple of prediction data loader, and whether the data is tiled.
+
+        Raises
+        ------
+        ValueError
+            If the configuration is None.
+        ValueError
+            If the mean or std are not specified in the configuration.
+        ValueError
+            If the input is None.
+        """
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined.")
+
+        if self.cfg.data.mean is None or self.cfg.data.std is None:
+            raise ValueError(
+                "Mean or std are not specified in the configuration, prediction cannot "
+                "be performed. Was the model trained?"
+            )
+
+        if input is None:
+            raise ValueError("Input cannot be None.")
+
+        # Create dataset
+        if isinstance(input, np.ndarray):  # np.ndarray
+            # Validate axes and add missing dimensions (S)C if necessary
+            img_axes = self.cfg.data.axes if axes is None else axes
+            input_expanded = add_axes(input, img_axes)
+
+            # Check if tiling requested
+            tiled = tile_shape is not None and overlaps is not None
+
+            # Validate tiles and overlaps
+            if tiled:
+                raise NotImplementedError(
+                    "Tiling with in memory array is currently not implemented."
+                )
+
+            # Normalize input and cast to float32
+            normalized_input = normalize(
+                img=input_expanded, mean=self.cfg.data.mean, std=self.cfg.data.std
+            )
+            normalized_input = normalized_input.astype(np.float32)
+
+            # Create dataset
+            dataset = TensorDataset(torch.from_numpy(normalized_input))
+
+        elif isinstance(input, str) or isinstance(input, Path):  # path
+            # Create dataset
+            dataset = get_prediction_dataset(
+                self.cfg,
+                pred_path=input,
+                tile_shape=tile_shape,
+                overlaps=overlaps,
+                axes=axes,
+            )
+
+            tiled = (
+                hasattr(dataset, "patch_extraction_method")
+                and dataset.patch_extraction_method is not None
+            )
+        return (
+            DataLoader(
+                dataset,
+                batch_size=1,
+                num_workers=0,
+                pin_memory=True,
+            ),
+            tiled,
+        )
+
+    def _save_checkpoint(
+        self, epoch: int, losses: List[float], save_method: str
+    ) -> Path:
+        """
+        Save checkpoint.
+
+        Currently only supports saving using `save_method="state_dict"`.
+
+        Parameters
+        ----------
+        epoch : int
+            Last epoch.
+        losses : List[float]
+            List of losses.
+        save_method : str
+            Method to save the model. Currently only supports `state_dict`.
+
+        Returns
+        -------
+        Path
+            Path to the saved checkpoint.
+
+        Raises
+        ------
+        ValueError
+            If the configuration is None.
+        NotImplementedError
+            If the requested save method is not supported.
+        """
+        if self.cfg is None:
+            raise ValueError("Configuration is not defined.")
+
+        if epoch == 0 or losses[-1] == min(losses):
+            name = f"{self.cfg.experiment_name}_best.pth"
+        else:
+            name = f"{self.cfg.experiment_name}_latest.pth"
+        workdir = self.cfg.working_directory
+        workdir.mkdir(parents=True, exist_ok=True)
+
+        if save_method == "state_dict":
+            checkpoint = {
+                "epoch": epoch,
+                "model_state_dict": self.model.state_dict(),
+                "optimizer_state_dict": self.optimizer.state_dict(),
+                "scheduler_state_dict": self.lr_scheduler.state_dict(),
+                "grad_scaler_state_dict": self.scaler.state_dict(),
+                "loss": losses[-1],
+                "config": self.cfg.model_dump(),
+            }
+            torch.save(checkpoint, workdir / name)
+        else:
+            raise NotImplementedError("Invalid save method.")
+
+        return self.cfg.working_directory.absolute() / name
+
+    def __del__(self) -> None:
+        """Exit logger."""
+        if hasattr(self, "logger"):
+            for handler in self.logger.handlers:
+                if isinstance(handler, FileHandler):
+                    self.logger.removeHandler(handler)
+                    handler.close()
+
+    def _get_sample_io_files(
+        self,
+        input_array: Optional[np.ndarray] = None,
+        axes: Optional[str] = None,
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Create numpy files for use as inputs and outputs in the bioimage.io archive.
+
+        Parameters
+        ----------
+        input_array : Optional[np.ndarray], optional
+            Input array to use for the bioimage.io model zoo, by default None.
+        axes : Optional[str], optional
+            Axes from the configuration.
+
+        Returns
+        -------
+        Tuple[List[str], List[str]]
+            Tuple of input and output file paths.
+
+        Raises
+        ------
+        ValueError
+            If the configuration is not defined.
+        """
+        if self.cfg is not None and self._input is not None:
+            # use the input array if provided, otherwise use the first validation sample
+            if input_array is not None:
+                array_in = input_array
+
+                # add axes to be compatible with the axes declared in the RDF specs
+                add_axes(array_in, axes)
+            else:
+                array_in = self._input
+
+            # predict (no tta since BMZ does not apply it)
+            array_out = self.predict(array_in, tta=False)
+
+            # add singleton dimensions (for compatibility with model axes)
+            # indeed, BMZ applies the model but CAREamics functions are meant
+            # to work on user data (potentially with no S or C axes)
+            array_out = array_out[np.newaxis, np.newaxis, ...]
+
+            # save numpy files
+            workdir = self.cfg.working_directory
+            in_file = workdir.joinpath("test_inputs.npy")
+            np.save(in_file, array_in)
+            out_file = workdir.joinpath("test_outputs.npy")
+            np.save(out_file, array_out)
+
+            return [str(in_file.absolute())], [str(out_file.absolute())]
+        else:
+            raise ValueError("Configuration is not defined or model was not trained.")
+
+    def _generate_rdf(
+        self,
+        *,
+        model_specs: Optional[dict] = None,
+        input_array: Optional[np.ndarray] = None,
+    ) -> dict:
+        """
+        Generate RDF data for bioimage.io format export.
+
+        Parameters
+        ----------
+        model_specs : Optional[dict], optional
+            Custom specs if different than the default ones, by default None.
+        input_array : Optional[np.ndarray], optional
+            Input array to use for the bioimage.io model zoo, by default None.
+
+        Returns
+        -------
+        dict
+            RDF specs.
+
+        Raises
+        ------
+        ValueError
+            If the mean or std are not specified in the configuration.
+        ValueError
+            If the configuration is not defined.
+        """
+        if self.cfg is not None:
+            if self.cfg.data.mean is None or self.cfg.data.std is None:
+                raise ValueError(
+                    "Mean or std are not specified in the configuration, export to "
+                    "bioimage.io format is not possible."
+                )
+
+            # set in/out axes from config
+            axes = self.cfg.data.axes.lower().replace("s", "")
+            if "c" not in axes:
+                axes = "c" + axes
+            if "b" not in axes:
+                axes = "b" + axes
+
+            # get in/out samples' files
+            test_inputs, test_outputs = self._get_sample_io_files(
+                input_array, self.cfg.data.axes
+            )
+
+            specs = get_default_model_specs(
+                "Noise2Void",
+                self.cfg.data.mean,
+                self.cfg.data.std,
+                self.cfg.algorithm.is_3D,
+            )
+            if model_specs is not None:
+                specs.update(model_specs)
+
+            specs.update(
+                {
+                    "test_inputs": test_inputs,
+                    "test_outputs": test_outputs,
+                    "input_axes": [axes],
+                    "output_axes": [axes],
+                }
+            )
+            return specs
+        else:
+            raise ValueError("Configuration is not defined or model was not trained.")
+
+    def save_as_bioimage(
+        self,
+        output_zip: Union[Path, str],
+        model_specs: Optional[dict] = None,
+        input_array: Optional[np.ndarray] = None,
+    ) -> None:
+        """
+        Export the current model to BioImage.io model zoo format.
+
+        Custom specs can be passed in `model_specs` (e.g. maintainers). For a description
+        of the model RDF, refer to
+        github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/model_spec_latest.md.
+
+        Parameters
+        ----------
+        output_zip : Union[Path, str]
+            Where to save the model zip file.
+        model_specs : Optional[dict]
+            A dictionary with keys being the bioimage-core build_model parameters. If
+            None then it will be populated by the model default specs.
+        input_array : Optional[np.ndarray]
+            An array to use as input for the bioimage.io model zoo. If None then the
+            first validation sample will be used. Note that the array must have S and
+            C dimensions (e.g. SCYX), even if only singleton dimensions.
+
+        Raises
+        ------
+        ValueError
+            If the configuration is not defined.
+        """
+        if self.cfg is not None:
+            # Generate specs
+            specs = self._generate_rdf(model_specs=model_specs, input_array=input_array)
+
+            # Build model
+            save_bioimage_model(
+                path=output_zip,
+                config=self.cfg,
+                specs=specs,
+            )
+        else:
+            raise ValueError("Configuration is not defined.")
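For orientation, here is a minimal usage sketch based on the Engine API added above. The configuration file, data folders and output path below are hypothetical placeholders (not files shipped with the package); the available configuration keys are defined in careamics/config/config.py.

    from careamics.engine import Engine

    # instantiate from a configuration file (hypothetical path)
    engine = Engine(config_path="n2v_config.yml")

    # train on folders compatible with the axes/format set in the configuration
    train_stats, eval_stats = engine.train(train_path="data/train", val_path="data/val")

    # tiled prediction with test-time augmentation on a folder of images
    prediction = engine.predict("data/test", tile_shape=[64, 64], overlaps=[32, 32], tta=True)

    # export the trained model to the bioimage.io model zoo format
    engine.save_as_bioimage("n2v_model.zip")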