careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (132)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
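
The list above amounts to a rewrite of the training stack: the monolithic Engine in careamics/engine.py (deleted, shown below) and the old config modules give way to a PyTorch Lightning based design built around careamics/careamist.py, careamics/lightning_module.py and careamics/lightning_datamodule.py. For orientation, the new entry point is presumably used along the following lines; this is a hypothetical sketch inferred from the new file names only, and none of these signatures appear in this diff.

# Hypothetical sketch of the new CAREamist entry point (careamics/careamist.py).
# The parameter names source, train_source and val_source are assumptions.
from careamics import CAREamist

careamist = CAREamist(source="config.yaml")  # configuration file, or a Configuration object
careamist.train(train_source="train/", val_source="val/")
prediction = careamist.predict(source="image.tif")
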
careamics/engine.py DELETED
@@ -1,954 +0,0 @@
-"""
-Engine module.
-
-This module contains the main CAREamics class, the Engine. The Engine allows training
-a model and using it for prediction.
-"""
-from logging import FileHandler
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from bioimageio.spec.model.raw_nodes import Model as BioimageModel
-from torch.utils.data import DataLoader, TensorDataset
-
-from .bioimage import (
-    build_zip_model,
-    get_default_model_specs,
-)
-from .config import Configuration, load_configuration
-from .dataset.prepare_dataset import (
-    get_prediction_dataset,
-    get_train_dataset,
-    get_validation_dataset,
-)
-from .losses import create_loss_function
-from .models import create_model
-from .prediction import (
-    stitch_prediction,
-    tta_backward,
-    tta_forward,
-)
-from .utils import (
-    MetricTracker,
-    check_array_validity,
-    denormalize,
-    get_device,
-    normalize,
-)
-from .utils.logging import ProgressBar, get_logger
-
-
-# TODO: refactor private methods and bioimage.io to other modules
-class Engine:
-    """
-    Class allowing training of a model and subsequent prediction.
-
-    There are three ways to instantiate an Engine:
-    1. With a CAREamics model (.pth), by passing a path.
-    2. With a configuration object.
-    3. With a configuration file, by passing a path.
-
-    In each case, the parameter name must be provided explicitly. For example:
-    >>> engine = Engine(config_path="path/to/config.yaml")
-
-    Note that only one of these options can be used at a time, in the order listed
-    above.
-
-    Parameters
-    ----------
-    config : Optional[Configuration], optional
-        Configuration object, by default None.
-    config_path : Optional[Union[str, Path]], optional
-        Path to configuration file, by default None.
-    model_path : Optional[Union[str, Path]], optional
-        Path to model file, by default None.
-    seed : int, optional
-        Seed for reproducibility, by default 42.
-
-    Attributes
-    ----------
-    cfg : Configuration
-        Configuration.
-    device : torch.device
-        Device (CPU or GPU).
-    model : torch.nn.Module
-        Model.
-    optimizer : torch.optim.Optimizer
-        Optimizer.
-    lr_scheduler : torch.optim.lr_scheduler._LRScheduler
-        Learning rate scheduler.
-    scaler : torch.cuda.amp.GradScaler
-        Gradient scaler.
-    loss_func : Callable
-        Loss function.
-    logger : logging.Logger
-        Logger.
-    use_wandb : bool
-        Whether to use wandb.
-    """
-
-    def __init__(
-        self,
-        *,
-        config: Optional[Configuration] = None,
-        config_path: Optional[Union[str, Path]] = None,
-        model_path: Optional[Union[str, Path]] = None,
-        seed: Optional[int] = 42,
-    ) -> None:
-        """
-        Constructor.
-
-        To disable the seed, set it to None.
-
-        Parameters
-        ----------
-        config : Optional[Configuration], optional
-            Configuration object, by default None.
-        config_path : Optional[Union[str, Path]], optional
-            Path to configuration file, by default None.
-        model_path : Optional[Union[str, Path]], optional
-            Path to model file, by default None.
-        seed : int, optional
-            Seed for reproducibility, by default 42.
-
-        Raises
-        ------
-        ValueError
-            If all three parameters are None.
-        FileNotFoundError
-            If the model or configuration path is provided but does not exist.
-        TypeError
-            If the configuration is not a Configuration object.
-        UsageError
-            If wandb is not correctly installed.
-        ModuleNotFoundError
-            If wandb is not installed.
-        ValueError
-            If the configuration failed to configure.
-        """
-        if model_path is not None:
-            if not Path(model_path).exists():
-                raise FileNotFoundError(
-                    f"Model path {model_path} is incorrect or"
-                    f" does not exist. Current working directory is: {Path.cwd()!s}"
-                )
-
-            # Ensure that config is None
-            self.cfg = None
-
-        elif config is not None:
-            # Check that config is a Configuration object
-            if not isinstance(config, Configuration):
-                raise TypeError(
-                    f"config must be a Configuration object, got {type(config)}"
-                )
-            self.cfg = config
-        elif config_path is not None:
-            self.cfg = load_configuration(config_path)
-        else:
-            raise ValueError(
-                "No configuration or path provided. One of configuration "
-                "object, configuration path or model path must be provided."
-            )
-
-        # get device, CPU or GPU
-        self.device = get_device()
-
-        # Create model, optimizer, lr scheduler and gradient scaler and load everything
-        # to the specified device
-        (
-            self.model,
-            self.optimizer,
-            self.lr_scheduler,
-            self.scaler,
-            self.cfg,
-        ) = create_model(config=self.cfg, model_path=model_path, device=self.device)
-
-        # create loss function
-        if self.cfg is not None:
-            self.loss_func = create_loss_function(self.cfg)
-
-            # Set logging
-            log_path = self.cfg.working_directory / "log.txt"
-            self.logger = get_logger(__name__, log_path=log_path)
-
-            # wandb
-            self.use_wandb = self.cfg.training.use_wandb
-
-            if self.use_wandb:
-                try:
-                    from wandb.errors import UsageError
-
-                    from careamics.utils.wandb import WandBLogging
-
-                    try:
-                        self.wandb = WandBLogging(
-                            experiment_name=self.cfg.experiment_name,
-                            log_path=self.cfg.working_directory,
-                            config=self.cfg,
-                            model_to_watch=self.model,
-                        )
-                    except UsageError as e:
-                        self.logger.warning(
-                            f"Wandb usage error, using default logger. Check whether "
-                            f"wandb correctly configured:\n"
-                            f"{e}"
-                        )
-                        self.use_wandb = False
-
-                except ModuleNotFoundError:
-                    self.logger.warning(
-                        "Wandb not installed, using default logger. Try pip install "
-                        "wandb"
-                    )
-                    self.use_wandb = False
-        else:
-            raise ValueError("Configuration is not defined.")
-
-    def train(
-        self,
-        train_path: str,
-        val_path: str,
-    ) -> Tuple[List[Any], List[Any]]:
-        """
-        Train the network.
-
-        The training and validation data given by the paths must be compatible with the
-        axes and data format provided in the configuration.
-
-        Parameters
-        ----------
-        train_path : Union[str, Path]
-            Path to the training data.
-        val_path : Union[str, Path]
-            Path to the validation data.
-
-        Returns
-        -------
-        Tuple[List[Any], List[Any]]
-            Tuple of training and validation statistics.
-
-        Raises
-        ------
-        ValueError
-            Raise a ValueError if the configuration is missing.
-        """
-        if self.cfg is None:
-            raise ValueError("Configuration is not defined, cannot train.")
-
-        # General func
-        train_loader = self._get_train_dataloader(train_path)
-
-        # Set mean and std from train dataset of none
-        if self.cfg.data.mean is None or self.cfg.data.std is None:
-            self.cfg.data.set_mean_and_std(
-                train_loader.dataset.mean, train_loader.dataset.std
-            )
-
-        eval_loader = self._get_val_dataloader(val_path)
-        self.logger.info(f"Starting training for {self.cfg.training.num_epochs} epochs")
-
-        val_losses = []
-
-        try:
-            train_stats = []
-            eval_stats = []
-
-            # loop over the dataset multiple times
-            for epoch in range(self.cfg.training.num_epochs):
-                if hasattr(train_loader.dataset, "__len__"):
-                    epoch_size = train_loader.__len__()
-                else:
-                    epoch_size = None
-
-                progress_bar = ProgressBar(
-                    max_value=epoch_size,
-                    epoch=epoch,
-                    num_epochs=self.cfg.training.num_epochs,
-                    mode="train",
-                )
-                # train_epoch = train_op(self._train_single_epoch,)
-                # Perform training step
-                train_outputs, epoch_size = self._train_single_epoch(
-                    train_loader,
-                    progress_bar,
-                    self.cfg.training.amp.use,
-                )
-                # Perform validation step
-                eval_outputs = self._evaluate(eval_loader)
-                val_losses.append(eval_outputs["loss"])
-                learning_rate = self.optimizer.param_groups[0]["lr"]
-
-                progress_bar.add(
-                    1,
-                    values=[
-                        ("train_loss", train_outputs["loss"]),
-                        ("val loss", eval_outputs["loss"]),
-                        ("lr", learning_rate),
-                    ],
-                )
-                # Add update scheduler rule based on type
-                self.lr_scheduler.step(eval_outputs["loss"])
-
-                if self.use_wandb:
-                    metrics = {
-                        "train": train_outputs,
-                        "eval": eval_outputs,
-                        "lr": learning_rate,
-                    }
-                    self.wandb.log_metrics(metrics)
-
-                train_stats.append(train_outputs)
-                eval_stats.append(eval_outputs)
-
-                checkpoint_path = self._save_checkpoint(epoch, val_losses, "state_dict")
-                self.logger.info(f"Saved checkpoint to {checkpoint_path}")
-
-        except KeyboardInterrupt:
-            self.logger.info("Training interrupted")
-
-        return train_stats, eval_stats
-
-    def _train_single_epoch(
-        self,
-        loader: torch.utils.data.DataLoader,
-        progress_bar: ProgressBar,
-        amp: bool,
-    ) -> Tuple[Dict[str, float], int]:
-        """
-        Train for a single epoch.
-
-        Parameters
-        ----------
-        loader : torch.utils.data.DataLoader
-            Training dataloader.
-        progress_bar : ProgressBar
-            Progress bar.
-        amp : bool
-            Whether to use automatic mixed precision.
-
-        Returns
-        -------
-        Tuple[Dict[str, float], int]
-            Tuple of training metrics and epoch size.
-
-        Raises
-        ------
-        ValueError
-            If the configuration is missing.
-        """
-        if self.cfg is not None:
-            avg_loss = MetricTracker()
-            self.model.train()
-            epoch_size = 0
-
-            for i, (batch, *auxillary) in enumerate(loader):
-                self.optimizer.zero_grad(set_to_none=True)
-
-                with torch.cuda.amp.autocast(enabled=amp):
-                    outputs = self.model(batch.to(self.device))
-
-                loss = self.loss_func(
-                    outputs, *[a.to(self.device) for a in auxillary], self.device
-                )
-                self.scaler.scale(loss).backward()
-                avg_loss.update(loss.detach(), batch.shape[0])
-
-                progress_bar.update(
-                    current_step=i,
-                    batch_size=self.cfg.training.batch_size,
-                )
-
-                self.optimizer.step()
-                epoch_size += 1
-
-            return {"loss": avg_loss.avg.to(torch.float16).cpu().numpy()}, epoch_size
-        else:
-            raise ValueError("Configuration is not defined, cannot train.")
-
-    def _evaluate(self, val_loader: torch.utils.data.DataLoader) -> Dict[str, float]:
-        """
-        Perform validation step.
-
-        Parameters
-        ----------
-        val_loader : torch.utils.data.DataLoader
-            Validation dataloader.
-
-        Returns
-        -------
-        Dict[str, float]
-            Loss value on the validation set.
-        """
-        self.model.eval()
-        avg_loss = MetricTracker()
-
-        with torch.no_grad():
-            for patch, *auxillary in val_loader:
-                outputs = self.model(patch.to(self.device))
-                loss = self.loss_func(
-                    outputs, *[a.to(self.device) for a in auxillary], self.device
-                )
-                avg_loss.update(loss.detach(), patch.shape[0])
-        return {"loss": avg_loss.avg.to(torch.float16).cpu().numpy()}
-
-    def predict(
-        self,
-        input: Union[np.ndarray, str, Path],
-        *,
-        tile_shape: Optional[List[int]] = None,
-        overlaps: Optional[List[int]] = None,
-        axes: Optional[str] = None,
-        tta: bool = True,
-    ) -> Union[np.ndarray, List[np.ndarray]]:
-        """
-        Predict using the current model on an input array or a path to data.
-
-        The Engine must have previously been trained and mean/std be specified in
-        its configuration.
-
-        To use tiling, both `tile_shape` and `overlaps` must be specified, have same
-        length, be divisible by 2 and greater than 0. Finally, the overlaps must be
-        smaller than the tiles.
-
-        Parameters
-        ----------
-        input : Union[np.ndarra, str, Path]
-            Input data, either an array or a path to the data.
-        tile_shape : Optional[List[int]], optional
-            2D or 3D shape of the tiles to be predicted, by default None.
-        overlaps : Optional[List[int]], optional
-            2D or 3D overlaps between tiles, by default None.
-        axes : Optional[str], optional
-            Axes of the input array if different from the one in the configuration, by
-            default None.
-        tta : bool, optional
-            Whether to use test time augmentation, by default True.
-
-        Returns
-        -------
-        Union[np.ndarray, List[np.ndarray]]
-            Predicted image array of the same shape as the input, or list of arrays
-            if the arrays have inconsistent shapes.
-
-        Raises
-        ------
-        ValueError
-            If the configuration is missing.
-        ValueError
-            If the mean or std are not specified in the configuration (untrained model).
-        """
-        if self.cfg is None:
-            raise ValueError("Configuration is not defined, cannot predict.")
-
-        # Check that the mean and std are there (= has been trained)
-        if not self.cfg.data.mean or not self.cfg.data.std:
-            raise ValueError(
-                "Mean or std are not specified in the configuration, prediction cannot "
-                "be performed."
-            )
-
-        # set model to eval mode
-        self.model.to(self.device)
-        self.model.eval()
-
-        progress_bar = ProgressBar(num_epochs=1, mode="predict")
-
-        # Get dataloader
-        pred_loader, tiled = self._get_predict_dataloader(
-            input=input, tile_shape=tile_shape, overlaps=overlaps, axes=axes
-        )
-
-        # Start prediction
-        self.logger.info("Starting prediction")
-        if tiled:
-            self.logger.info("Starting tiled prediction")
-            prediction = self._predict_tiled(pred_loader, progress_bar, tta)
-        else:
-            self.logger.info("Starting prediction on whole sample")
-            prediction = self._predict_full(pred_loader, progress_bar, tta)
-
-        return prediction
-
-    def _predict_tiled(
-        self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True
-    ) -> Union[np.ndarray, List[np.ndarray]]:
-        """
-        Predict using tiling.
-
-        Parameters
-        ----------
-        pred_loader : DataLoader
-            Prediction dataloader.
-        progress_bar : ProgressBar
-            Progress bar.
-        tta : bool, optional
-            Whether to use test time augmentation, by default True.
-
-        Returns
-        -------
-        Union[np.ndarray, List[np.ndarray]]
-            Predicted image, or list of predictions if the images have different sizes.
-
-        Warns
-        -----
-        UserWarning
-            If the samples have different shapes, the prediction then returns a list.
-        """
-        prediction = []
-        tiles = []
-        stitching_data = []
-
-        with torch.no_grad():
-            for i, (tile, *auxillary) in enumerate(pred_loader):
-                # Unpack auxillary data into last tile indicator and data, required to
-                # stitch tiles together
-                if auxillary:
-                    last_tile, *data = auxillary
-
-                if tta:
-                    augmented_tiles = tta_forward(tile)
-                    predicted_augments = []
-                    for augmented_tile in augmented_tiles:
-                        augmented_pred = self.model(augmented_tile.to(self.device))
-                        predicted_augments.append(
-                            augmented_pred.squeeze().cpu().numpy()
-                        )
-                    tiles.append(tta_backward(predicted_augments).squeeze())
-                else:
-                    tiles.append(
-                        self.model(tile.to(self.device)).squeeze().cpu().numpy()
-                    )
-
-                stitching_data.append(data)
-
-                if last_tile:
-                    # Stitch tiles together if sample is finished
-                    predicted_sample = stitch_prediction(tiles, stitching_data)
-                    predicted_sample = denormalize(
-                        predicted_sample,
-                        float(self.cfg.data.mean),  # type: ignore
-                        float(self.cfg.data.std),  # type: ignore
-                    )
-                    prediction.append(predicted_sample)
-                    tiles.clear()
-                    stitching_data.clear()
-
-                progress_bar.update(i, 1)
-        if tta:
-            i = int(i / 8)
-        self.logger.info(f"Predicted {len(prediction)} samples, {i} tiles in total")
-        try:
-            return np.stack(prediction)
-        except ValueError:
-            self.logger.warning("Samples have different shapes, returning list.")
-            return prediction
-
-    def _predict_full(
-        self, pred_loader: DataLoader, progress_bar: ProgressBar, tta: bool = True
-    ) -> np.ndarray:
-        """
-        Predict whole image without tiling.
-
-        Parameters
-        ----------
-        pred_loader : DataLoader
-            Prediction dataloader.
-        progress_bar : ProgressBar
-            Progress bar.
-        tta : bool, optional
-            Whether to use test time augmentation, by default True.
-
-        Returns
-        -------
-        np.ndarray
-            Predicted image.
-        """
-        prediction = []
-        with torch.no_grad():
-            for i, sample in enumerate(pred_loader):
-                if tta:
-                    augmented_preds = tta_forward(sample[0])
-                    predicted_augments = []
-                    for augmented_pred in augmented_preds:
-                        augmented_pred = self.model(augmented_pred.to(self.device))
-                        predicted_augments.append(
-                            augmented_pred.squeeze().cpu().numpy()
-                        )
-                    prediction.append(tta_backward(predicted_augments).squeeze())
-                else:
-                    prediction.append(
-                        self.model(sample[0].to(self.device)).squeeze().cpu().numpy()
-                    )
-                progress_bar.update(i, 1)
-        output = denormalize(
-            np.stack(prediction), float(self.cfg.data.mean), float(self.cfg.data.std)  # type: ignore
-        )
-        return output
-
-    def _get_train_dataloader(self, train_path: str) -> DataLoader:
-        """
-        Return a training dataloader.
-
-        Parameters
-        ----------
-        train_path : str
-            Path to the training data.
-
-        Returns
-        -------
-        DataLoader
-            Training data loader.
-
-        Raises
-        ------
-        ValueError
-            If the training configuration is None.
-        """
-        if self.cfg is None:
-            raise ValueError("Configuration is not defined.")
-
-        dataset = get_train_dataset(self.cfg, train_path)
-        dataloader = DataLoader(
-            dataset,
-            batch_size=self.cfg.training.batch_size,
-            num_workers=self.cfg.training.num_workers,
-            pin_memory=True,
-        )
-        return dataloader
-
-    def _get_val_dataloader(self, val_path: str) -> DataLoader:
-        """
-        Return a validation dataloader.
-
-        Parameters
-        ----------
-        val_path : str
-            Path to the validation data.
-
-        Returns
-        -------
-        DataLoader
-            Validation data loader.
-
-        Raises
-        ------
-        ValueError
-            If the configuration is None.
-        """
-        if self.cfg is None:
-            raise ValueError("Configuration is not defined.")
-
-        dataset = get_validation_dataset(self.cfg, val_path)
-        dataloader = DataLoader(
-            dataset,
-            batch_size=self.cfg.training.batch_size,
-            num_workers=self.cfg.training.num_workers,
-            pin_memory=True,
-        )
-        return dataloader
-
-    def _get_predict_dataloader(
-        self,
-        input: Union[np.ndarray, str, Path],
-        *,
-        tile_shape: Optional[List[int]] = None,
-        overlaps: Optional[List[int]] = None,
-        axes: Optional[str] = None,
-    ) -> Tuple[DataLoader, bool]:
-        """
-        Return a prediction dataloader.
-
-        Parameters
-        ----------
-        input : Union[np.ndarray, str, Path]
-            Input array or path to data.
-        tile_shape : Optional[List[int]], optional
-            2D or 3D shape of the tiles, by default None.
-        overlaps : Optional[List[int]], optional
-            2D or 3D overlaps between tiles, by default None.
-        axes : Optional[str], optional
-            Axes of the input array if different from the one in the configuration.
-
-        Returns
-        -------
-        Tuple[DataLoader, bool]
-            Tuple of prediction data loader, and whether the data is tiled.
-
-        Raises
-        ------
-        ValueError
-            If the configuration is None.
-        ValueError
-            If the mean or std are not specified in the configuration.
-        ValueError
-            If the input is None.
-        """
-        if self.cfg is None:
-            raise ValueError("Configuration is not defined.")
-
-        if self.cfg.data.mean is None or self.cfg.data.std is None:
-            raise ValueError(
-                "Mean or std are not specified in the configuration, prediction cannot "
-                "be performed. Was the model trained?"
-            )
-
-        if input is None:
-            raise ValueError("Ccannot predict on None input.")
-
-        # Create dataset
-        if isinstance(input, np.ndarray):  # np.ndarray
-            # Check that the axes fit the input
-            img_axes = self.cfg.data.axes if axes is None else axes
-            # TODO are self.cfg.data.axes and axes compatible (same spatial dim)?
-            check_array_validity(input, img_axes)
-
-            # Check if tiling requested
-            tiled = tile_shape is not None and overlaps is not None
-
-            # Validate tiles and overlaps
-            if tiled:
-                raise NotImplementedError(
-                    "Tiling with in memory array is currently not implemented."
-                )
-
-                # check_tiling_validity(tile_shape, overlaps)
-
-            # Normalize input and cast to float32
-            normalized_input = normalize(
-                img=input, mean=self.cfg.data.mean, std=self.cfg.data.std
-            )
-            normalized_input = normalized_input.astype(np.float32)
-
-            # Create dataset
-            dataset = TensorDataset(torch.from_numpy(normalized_input))
-
-        elif isinstance(input, str) or isinstance(input, Path):  # path
-            # Create dataset
-            dataset = get_prediction_dataset(
-                self.cfg,
-                pred_path=input,
-                tile_shape=tile_shape,
-                overlaps=overlaps,
-                axes=axes,
-            )
-
-            tiled = (
-                hasattr(dataset, "patch_extraction_method")
-                and dataset.patch_extraction_method is not None
-            )
-        return (
-            DataLoader(
-                dataset,
-                batch_size=1,
-                num_workers=0,
-                pin_memory=True,
-            ),
-            tiled,
-        )
-
-    def _save_checkpoint(
-        self, epoch: int, losses: List[float], save_method: str
-    ) -> Path:
-        """
-        Save checkpoint.
-
-        Currently only supports saving using `save_method="state_dict"`.
-
-        Parameters
-        ----------
-        epoch : int
-            Last epoch.
-        losses : List[float]
-            List of losses.
-        save_method : str
-            Method to save the model. Currently only supports `state_dict`.
-
-        Returns
-        -------
-        Path
-            Path to the saved checkpoint.
-
-        Raises
-        ------
-        ValueError
-            If the configuration is None.
-        NotImplementedError
-            If the requested save method is not supported.
-        """
-        if self.cfg is None:
-            raise ValueError("Configuration is not defined.")
-
-        if epoch == 0 or losses[-1] == min(losses):
-            name = f"{self.cfg.experiment_name}_best.pth"
-        else:
-            name = f"{self.cfg.experiment_name}_latest.pth"
-        workdir = self.cfg.working_directory
-        workdir.mkdir(parents=True, exist_ok=True)
-
-        if save_method == "state_dict":
-            checkpoint = {
-                "epoch": epoch,
-                "model_state_dict": self.model.state_dict(),
-                "optimizer_state_dict": self.optimizer.state_dict(),
-                "scheduler_state_dict": self.lr_scheduler.state_dict(),
-                "grad_scaler_state_dict": self.scaler.state_dict(),
-                "loss": losses[-1],
-                "config": self.cfg.model_dump(),
-            }
-            torch.save(checkpoint, workdir / name)
-        else:
-            raise NotImplementedError("Invalid save method.")
-
-        return self.cfg.working_directory.absolute() / name
-
-    def __del__(self) -> None:
-        """Exit logger."""
-        if hasattr(self, "logger"):
-            for handler in self.logger.handlers:
-                if isinstance(handler, FileHandler):
-                    self.logger.removeHandler(handler)
-                    handler.close()
-
-    def _generate_rdf(self, model_specs: Optional[dict] = None) -> dict:
-        """
-        Generate rdf data for bioimage.io format export.
-
-        Parameters
-        ----------
-        model_specs : Optional[dict], optional
-            Custom specs if different than the default ones, by default None.
-
-        Returns
-        -------
-        dict
-            RDF specs.
-
-        Raises
-        ------
-        ValueError
-            If the mean or std are not specified in the configuration.
-        ValueError
-            If the configuration is not defined.
-        """
-        if self.cfg is not None:
-            if self.cfg.data.mean is None or self.cfg.data.std is None:
-                raise ValueError(
-                    "Mean or std are not specified in the configuration, export to "
-                    "bioimage.io format is not possible."
-                )
-
-            # set in/out axes from config
-            axes = self.cfg.data.axes.lower().replace("s", "")
-            if "c" not in axes:
-                axes = "c" + axes
-            if "b" not in axes:
-                axes = "b" + axes
-
-            # get in/out samples' files
-            test_inputs, test_outputs = self._get_sample_io_files(axes)
-
-            specs = get_default_model_specs(
-                "Noise2Void",
-                self.cfg.data.mean,
-                self.cfg.data.std,
-                self.cfg.algorithm.is_3D,
-            )
-            if model_specs is not None:
-                specs.update(model_specs)
-
-            specs.update(
-                {
-                    "architecture": "careamics.models.unet",
-                    "test_inputs": test_inputs,
-                    "test_outputs": test_outputs,
-                    "input_axes": [axes],
-                    "output_axes": [axes],
-                }
-            )
-            return specs
-        else:
-            raise ValueError("Configuration is not defined.")
-
-    def save_as_bioimage(
-        self, output_zip: Union[Path, str], model_specs: Optional[dict] = None
-    ) -> BioimageModel:
-        """
-        Export the current model to BioImage.io model zoo format.
-
-        Parameters
-        ----------
-        output_zip : Union[Path, str]
-            Where to save the model zip file.
-        model_specs : Optional[dict]
-            A dictionary with keys being the bioimage-core build_model parameters. If
-            None then it will be populated by the model default specs.
-
-        Returns
-        -------
-        BioimageModel
-            Bioimage.io model object.
-
-        Raises
-        ------
-        ValueError
-            If the configuration is not defined.
-        """
-        if self.cfg is not None:
-            # Generate specs
-            specs = self._generate_rdf(model_specs)
-
-            # Build model
-            raw_model = build_zip_model(
-                path=output_zip,
-                config=self.cfg,
-                model_specs=specs,
-            )
-
-            return raw_model
-        else:
-            raise ValueError("Configuration is not defined.")
-
-    def _get_sample_io_files(self, axes: str) -> Tuple[List[str], List[str]]:
-        """
-        Create numpy format for use as inputs and outputs in the bioimage.io archive.
-
-        Parameters
-        ----------
-        axes : str
-            Input and output axes.
-
-        Returns
-        -------
-        Tuple[List[str], List[str]]
-            Tuple of input and output file paths.
-
-        Raises
-        ------
-        ValueError
-            If the configuration is not defined.
-        """
-        # input:
-        if self.cfg is not None:
-            sample_input = np.random.randn(*self.cfg.training.patch_size)
-            # if there are more input axes (like channel, ...),
-            # then expand the sample dimensions.
-            len_diff = len(axes) - len(self.cfg.training.patch_size)
-            if len_diff > 0:
-                sample_input = np.expand_dims(
-                    sample_input, axis=tuple(i for i in range(len_diff))
-                )
-            sample_output = np.random.randn(*sample_input.shape)
-
-            # save numpy files
-            workdir = self.cfg.working_directory
-            in_file = workdir.joinpath("test_inputs.npy")
-            np.save(in_file, sample_input)
-            out_file = workdir.joinpath("test_outputs.npy")
-            np.save(out_file, sample_output)
-
-            return [str(in_file.absolute())], [str(out_file.absolute())]
-        else:
-            raise ValueError("Configuration is not defined.")
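
For reference, the removed Engine documents its own workflow in the docstrings above. The sketch below is assembled solely from the deleted module's signatures and docstring example; the paths and tile sizes are placeholders, and this API no longer exists in 0.1.0rc3.

# Workflow of the removed Engine (0.1.0rc1), reconstructed from the deleted module above.
from careamics.engine import Engine

# One of config, config_path or model_path must be passed explicitly.
engine = Engine(config_path="path/to/config.yaml")

# train() returns a tuple of training and validation statistics.
train_stats, eval_stats = engine.train(train_path="train/", val_path="val/")

# Tiled prediction: tile_shape and overlaps must have the same length, be
# divisible by 2 and greater than 0, and overlaps must be smaller than the tiles.
prediction = engine.predict("image.tif", tile_shape=[64, 64], overlaps=[32, 32])

# Export the trained model to the BioImage.io model zoo format.
engine.save_as_bioimage("model.zip")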