careamics 0.1.0rc7__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Files changed (54)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +83 -62
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -0
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +2 -0
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +1 -79
  11. careamics/config/configuration_model.py +12 -7
  12. careamics/config/data_model.py +29 -10
  13. careamics/config/inference_model.py +12 -2
  14. careamics/config/optimizer_models.py +6 -0
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/tile_information.py +10 -0
  17. careamics/config/training_model.py +5 -1
  18. careamics/dataset/dataset_utils/__init__.py +0 -6
  19. careamics/dataset/dataset_utils/file_utils.py +1 -1
  20. careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
  21. careamics/dataset/in_memory_dataset.py +37 -21
  22. careamics/dataset/iterable_dataset.py +38 -34
  23. careamics/dataset/iterable_pred_dataset.py +2 -1
  24. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  25. careamics/dataset/patching/patching.py +53 -37
  26. careamics/file_io/__init__.py +7 -0
  27. careamics/file_io/read/__init__.py +11 -0
  28. careamics/file_io/read/get_func.py +56 -0
  29. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -1
  30. careamics/file_io/write/__init__.py +9 -0
  31. careamics/file_io/write/get_func.py +59 -0
  32. careamics/file_io/write/tiff.py +39 -0
  33. careamics/lightning/__init__.py +17 -0
  34. careamics/{lightning_module.py → lightning/lightning_module.py} +58 -85
  35. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +78 -116
  36. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +134 -214
  37. careamics/model_io/bmz_io.py +1 -1
  38. careamics/model_io/model_io_utils.py +1 -1
  39. careamics/prediction_utils/__init__.py +0 -2
  40. careamics/prediction_utils/prediction_outputs.py +18 -46
  41. careamics/prediction_utils/stitch_prediction.py +17 -14
  42. careamics/utils/__init__.py +2 -0
  43. careamics/utils/autocorrelation.py +40 -0
  44. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +1 -1
  45. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/RECORD +51 -46
  46. careamics/config/configuration_example.py +0 -86
  47. careamics/dataset/dataset_utils/read_utils.py +0 -27
  48. careamics/prediction_utils/create_pred_datamodule.py +0 -185
  49. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  50. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  51. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  52. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  53. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +0 -0
  54. {careamics-0.1.0rc7.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/__init__.py CHANGED
@@ -7,20 +7,7 @@ try:
 except PackageNotFoundError:
     __version__ = "uninstalled"
 
-__all__ = [
-    "CAREamist",
-    "CAREamicsModuleWrapper",
-    "CAREamicsPredictData",
-    "CAREamicsTrainData",
-    "Configuration",
-    "load_configuration",
-    "save_configuration",
-    "TrainingDataWrapper",
-    "PredictDataWrapper",
-]
+__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]
 
 from .careamist import CAREamist
 from .config import Configuration, load_configuration, save_configuration
-from .lightning_datamodule import CAREamicsTrainData, TrainingDataWrapper
-from .lightning_module import CAREamicsModuleWrapper
-from .lightning_prediction_datamodule import CAREamicsPredictData, PredictDataWrapper
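The top-level namespace is reduced to the high-level API; the Lightning-oriented classes now live in the new careamics.lightning subpackage (see the careamist.py imports below). A hedged migration sketch for downstream code, assuming the names re-exported by careamics.lightning match those imported in careamist.py:

# 0.1.0rc7 (no longer available at the package root):
# from careamics import CAREamicsTrainData, CAREamicsPredictData, CAREamicsModuleWrapper

# 0.1.0rc8: only the high-level API stays at the package root...
from careamics import CAREamist, Configuration, load_configuration, save_configuration

# ...while the Lightning building blocks come from careamics.lightning
from careamics.lightning import (
    CAREamicsModule,
    PredictDataModule,
    TrainDataModule,
    create_predict_datamodule,
)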
careamics/careamist.py CHANGED
@@ -13,22 +13,29 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 
-from careamics.callbacks import ProgressBarCallback
 from careamics.config import (
     Configuration,
     load_configuration,
 )
-from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
+from careamics.config.support import (
+    SupportedAlgorithm,
+    SupportedArchitecture,
+    SupportedData,
+    SupportedLogger,
+)
 from careamics.dataset.dataset_utils import reshape_array
-from careamics.lightning_datamodule import CAREamicsTrainData
-from careamics.lightning_module import CAREamicsModule
+from careamics.lightning import (
+    CAREamicsModule,
+    HyperParametersCallback,
+    PredictDataModule,
+    ProgressBarCallback,
+    TrainDataModule,
+    create_predict_datamodule,
+)
 from careamics.model_io import export_to_bmz, load_pretrained
-from careamics.prediction_utils import convert_outputs, create_pred_datamodule
+from careamics.prediction_utils import convert_outputs
 from careamics.utils import check_path_exists, get_logger
 
-from .callbacks import HyperParametersCallback
-from .lightning_prediction_datamodule import CAREamicsPredictData
-
 logger = get_logger(__name__)
 
 LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
@@ -61,9 +68,9 @@ class CAREamist:
         Experiment logger, "wandb" or "tensorboard".
     work_dir : pathlib.Path
         Working directory.
-    train_datamodule : CAREamicsTrainData
+    train_datamodule : TrainDataModule
         Training datamodule.
-    pred_datamodule : CAREamicsPredictData
+    pred_datamodule : PredictDataModule
        Prediction datamodule.
    """
 
@@ -193,8 +200,8 @@ class CAREamist:
        )
 
        # place holder for the datamodules
-        self.train_datamodule: Optional[CAREamicsTrainData] = None
-        self.pred_datamodule: Optional[CAREamicsPredictData] = None
+        self.train_datamodule: Optional[TrainDataModule] = None
+        self.pred_datamodule: Optional[PredictDataModule] = None
 
    def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
        """
@@ -246,7 +253,7 @@ class CAREamist:
    def train(
        self,
        *,
-        datamodule: Optional[CAREamicsTrainData] = None,
+        datamodule: Optional[TrainDataModule] = None,
        train_source: Optional[Union[Path, str, NDArray]] = None,
        val_source: Optional[Union[Path, str, NDArray]] = None,
        train_target: Optional[Union[Path, str, NDArray]] = None,
@@ -273,7 +280,7 @@ class CAREamist:
 
        Parameters
        ----------
-        datamodule : CAREamicsTrainData, optional
+        datamodule : TrainDataModule, optional
            Datamodule to train on, by default None.
        train_source : pathlib.Path or str or NDArray, optional
            Train source, if no datamodule is provided, by default None.
@@ -375,17 +382,17 @@ class CAREamist:
 
        else:
            raise ValueError(
-                f"Invalid input, expected a str, Path, array or CAREamicsTrainData "
+                f"Invalid input, expected a str, Path, array or TrainDataModule "
                f"instance (got {type(train_source)})."
            )
 
-    def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
+    def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
        """
        Train the model on the provided datamodule.
 
        Parameters
        ----------
-        datamodule : CAREamicsTrainData
+        datamodule : TrainDataModule
            Datamodule to train on.
        """
        # record datamodule
@@ -421,7 +428,7 @@ class CAREamist:
            Minimum number of patches to use for validation, by default 5.
        """
        # create datamodule
-        datamodule = CAREamicsTrainData(
+        datamodule = TrainDataModule(
            data_config=self.cfg.data_config,
            train_data=train_data,
            val_data=val_data,
@@ -477,7 +484,7 @@ class CAREamist:
            path_to_val_target = check_path_exists(path_to_val_target)
 
        # create datamodule
-        datamodule = CAREamicsTrainData(
+        datamodule = TrainDataModule(
            data_config=self.cfg.data_config,
            train_data=path_to_train_data,
            val_data=path_to_val_data,
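The rename only touches type annotations and error messages here; training through the high-level API is unchanged. A minimal usage sketch, assuming CAREamist accepts a Configuration and that create_n2v_configuration keeps its documented keyword arguments (only the train() keywords below appear in this diff):

import numpy as np

from careamics import CAREamist
from careamics.config import create_n2v_configuration

# assumed factory keywords; the configuration factory itself is unchanged in this release
config = create_n2v_configuration(
    experiment_name="rc8_demo",
    data_type="array",
    axes="YX",
    patch_size=(64, 64),
    batch_size=8,
    num_epochs=1,
)

careamist = CAREamist(config)

rng = np.random.default_rng(seed=42)
train_image = rng.normal(size=(256, 256)).astype(np.float32)
val_image = rng.normal(size=(256, 256)).astype(np.float32)

# keyword-only arguments, matching the updated signature above
careamist.train(train_source=train_image, val_source=val_image)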
@@ -493,10 +500,7 @@ class CAREamist:
 
    @overload
    def predict(  # numpydoc ignore=GL08
-        self,
-        source: CAREamicsPredictData,
-        *,
-        checkpoint: Optional[Literal["best", "last"]] = None,
+        self, source: PredictDataModule
    ) -> Union[list[NDArray], NDArray]: ...
 
    @overload
@@ -513,7 +517,6 @@ class CAREamist:
        dataloader_params: Optional[dict] = None,
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
-        checkpoint: Optional[Literal["best", "last"]] = None,
    ) -> Union[list[NDArray], NDArray]: ...
 
    @overload
@@ -528,12 +531,11 @@ class CAREamist:
        data_type: Optional[Literal["array"]] = None,
        tta_transforms: bool = True,
        dataloader_params: Optional[dict] = None,
-        checkpoint: Optional[Literal["best", "last"]] = None,
    ) -> Union[list[NDArray], NDArray]: ...
 
    def predict(
        self,
-        source: Union[CAREamicsPredictData, Path, str, NDArray],
+        source: Union[PredictDataModule, Path, str, NDArray],
        *,
        batch_size: Optional[int] = None,
        tile_size: Optional[tuple[int, ...]] = None,
@@ -544,7 +546,6 @@ class CAREamist:
        dataloader_params: Optional[dict] = None,
        read_source_func: Optional[Callable] = None,
        extension_filter: str = "",
-        checkpoint: Optional[Literal["best", "last"]] = None,
        **kwargs: Any,
    ) -> Union[list[NDArray], NDArray]:
        """
@@ -590,8 +591,6 @@ class CAREamist:
            Function to read the source data.
        extension_filter : str, default=""
            Filter for the file extension.
-        checkpoint : {"best", "last"}, optional
-            Checkpoint to use for prediction.
        **kwargs : Any
            Unused.
 
@@ -599,31 +598,64 @@ class CAREamist:
        -------
        list of NDArray or NDArray
            Predictions made by the model.
-        """
-        # Reuse batch size if not provided explicitly
-        if batch_size is None:
-            batch_size = (
-                self.train_datamodule.batch_size
-                if self.train_datamodule
-                else self.cfg.data_config.batch_size
-            )
 
-        self.pred_datamodule = create_pred_datamodule(
-            source=source,
-            config=self.cfg,
-            batch_size=batch_size,
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the configuration.
+        ValueError
+            If tile size is not divisible by 2**depth for UNet models.
+        ValueError
+            If tile overlap is not specified.
+        """
+        if (
+            self.cfg.data_config.image_means is None
+            or self.cfg.data_config.image_stds is None
+        ):
+            raise ValueError("Mean and std must be provided in the configuration.")
+
+        # tile size for UNets
+        if tile_size is not None:
+            model = self.cfg.algorithm_config.model
+
+            if model.architecture == SupportedArchitecture.UNET.value:
+                # tile size must be equal to k*2^n, where n is the number of pooling
+                # layers (equal to the depth) and k is an integer
+                depth = model.depth
+                tile_increment = 2**depth
+
+                for i, t in enumerate(tile_size):
+                    if t % tile_increment != 0:
+                        raise ValueError(
+                            f"Tile size must be divisible by {tile_increment} along "
+                            f"all axes (got {t} for axis {i}). If your image size is "
+                            f"smaller along one axis (e.g. Z), consider padding the "
+                            f"image."
+                        )
+
+            # tile overlaps must be specified
+            if tile_overlap is None:
+                raise ValueError("Tile overlap must be specified.")
+
+        # create the prediction
+        self.pred_datamodule = create_predict_datamodule(
+            pred_data=source,
+            data_type=data_type or self.cfg.data_config.data_type,
+            axes=axes or self.cfg.data_config.axes,
+            image_means=self.cfg.data_config.image_means,
+            image_stds=self.cfg.data_config.image_stds,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
-            axes=axes,
-            data_type=data_type,
+            batch_size=batch_size or self.cfg.data_config.batch_size,
            tta_transforms=tta_transforms,
-            dataloader_params=dataloader_params,
            read_source_func=read_source_func,
            extension_filter=extension_filter,
+            dataloader_params=dataloader_params,
        )
 
+        # predict
        predictions = self.trainer.predict(
-            model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint
+            model=self.model, datamodule=self.pred_datamodule
        )
        return convert_outputs(predictions, self.pred_datamodule.tiled)
 
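The checks that previously lived in create_inference_configuration now run inline before the datamodule is built, so tiled prediction fails early on incompatible tile sizes. A small sketch of the rule, plus a hedged predict() call continuing the hypothetical careamist from the training sketch above (values are illustrative):

# For UNets, every tile dimension must be a multiple of 2**depth.
depth = 2                    # default UNet depth
tile_increment = 2 ** depth  # 4

assert all(t % tile_increment == 0 for t in (256, 256))     # valid
assert not all(t % tile_increment == 0 for t in (62, 64))   # 62 % 4 != 0 -> ValueError

prediction = careamist.predict(
    "data/test",             # path, array, or PredictDataModule
    data_type="tiff",
    tile_size=(256, 256),
    tile_overlap=(48, 48),   # required whenever tile_size is given
    batch_size=1,
    tta_transforms=True,
)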
@@ -659,27 +691,16 @@ class CAREamist:
        data_description : str, optional
            Description of the data, by default None.
        """
-        input_patch = reshape_array(input_array, self.cfg.data_config.axes)
-
-        # axes need to be reformated for the export because reshaping was done in the
-        # datamodule
-        if "Z" in self.cfg.data_config.axes:
-            axes = "SCZYX"
-        else:
-            axes = "SCYX"
+        # TODO: add in docs that it is expected that input_array dimensions match
+        # those in data_config
 
-        # predict output, remove extra dimensions for the purpose of the prediction
        output_patch = self.predict(
-            input_patch,
+            input_array,
            data_type=SupportedData.ARRAY.value,
-            axes=axes,
            tta_transforms=False,
        )
-
-        if isinstance(output_patch, list):
-            output = np.concatenate(output_patch, axis=0)
-        else:
-            output = output_patch
+        output = np.concatenate(output_patch, axis=0)
+        input_array = reshape_array(input_array, self.cfg.data_config.axes)
 
        export_to_bmz(
            model=self.model,
@@ -688,7 +709,7 @@ class CAREamist:
            name=name,
            general_description=general_description,
            authors=authors,
-            input_array=input_patch,
+            input_array=input_array,
            output_array=output,
            channel_names=channel_names,
            data_description=data_description,
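export_to_bmz no longer reshapes and re-axes the sample internally before prediction; the input array is now expected to already match data_config.axes (see the TODO above). A heavily hedged sketch continuing the earlier example; the destination-path parameter name is not shown in this diff and is assumed:

careamist.export_to_bmz(
    path="n2v_demo_model.zip",   # assumed parameter name for the output archive
    name="n2v_demo_model",
    authors=[{"name": "Jane Doe", "affiliation": "Example Lab"}],
    input_array=train_image,     # dimensions must match data_config.axes ("YX" here)
    general_description="N2V model exported for the BioImage Model Zoo",
)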
careamics/config/__init__.py CHANGED
@@ -14,9 +14,7 @@ __all__ = [
     "create_care_configuration",
     "register_model",
     "CustomModel",
-    "create_inference_configuration",
     "clear_custom_models",
-    "ConfigurationInformation",
 ]
 
 from .algorithm_model import AlgorithmConfig
@@ -24,7 +22,6 @@ from .architectures import CustomModel, clear_custom_models, register_model
 from .callback_model import CheckpointModel
 from .configuration_factory import (
     create_care_configuration,
-    create_inference_configuration,
     create_n2n_configuration,
     create_n2v_configuration,
 )
careamics/config/algorithm_model.py CHANGED
@@ -93,12 +93,20 @@ class AlgorithmConfig(BaseModel):
 
    # Mandatory fields
    algorithm: Literal["n2v", "care", "n2n", "custom"]  # defined in SupportedAlgorithm
+    """Name of the algorithm, as defined in SupportedAlgorithm."""
+
    loss: Literal["n2v", "mae", "mse"]
+    """Loss function to use, as defined in SupportedLoss."""
+
    model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
+    """Model architecture to use, defined in SupportedArchitecture."""
 
    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
+    """Optimizer to use, defined in SupportedOptimizer."""
+
    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+    """Learning rate scheduler to use, defined in SupportedScheduler."""
 
    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
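The new field docstrings only document existing behaviour. A short sketch constructing the model with the fields described above; the dict is validated into UNetModel through the architecture discriminator (values are illustrative and must still satisfy algorithm_cross_validation):

from careamics.config.algorithm_model import AlgorithmConfig

algorithm_config = AlgorithmConfig(
    algorithm="n2v",
    loss="n2v",
    model={"architecture": "UNet"},  # resolved to UNetModel via the discriminator
)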
careamics/config/architectures/architecture_model.py CHANGED
@@ -13,6 +13,7 @@ class ArchitectureModel(BaseModel):
    """
 
    architecture: str
+    """Name of the architecture."""
 
    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """
careamics/config/architectures/custom_model.py CHANGED
@@ -72,9 +72,11 @@ class CustomModel(ArchitectureModel):
 
    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["Custom"]
+    """Name of the architecture."""
 
    # name of the custom model
    name: str
+    """Name of the custom model."""
 
    @field_validator("name")
    @classmethod
careamics/config/architectures/unet_model.py CHANGED
@@ -29,19 +29,38 @@ class UNetModel(ArchitectureModel):
 
    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["UNet"]
+    """Name of the architecture."""
 
    # parameters
    # validate_defaults allow ignoring default values in the dump if they were not set
    conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
+    """Dimensions (2D or 3D) of the convolutional layers."""
+
    num_classes: int = Field(default=1, ge=1, validate_default=True)
+    """Number of classes or channels in the model output."""
+
    in_channels: int = Field(default=1, ge=1, validate_default=True)
+    """Number of channels in the input to the model."""
+
    depth: int = Field(default=2, ge=1, le=10, validate_default=True)
+    """Number of levels in the UNet."""
+
    num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
+    """Number of convolutional filters in the first layer of the UNet."""
+
    final_activation: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
    ] = Field(default="None", validate_default=True)
+    """Final activation function."""
+
    n2v2: bool = Field(default=False, validate_default=True)
+    """Whether to use N2V2 architecture modifications, with blur pool layers and fewer
+    skip connections.
+    """
+
    independent_channels: bool = Field(default=True, validate_default=True)
+    """Whether information is processed independently in each channel, used to train
+    channels independently."""
 
    @field_validator("num_channels_init")
    @classmethod
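These docstrings make the UNet fields self-documenting. A hedged instantiation sketch using only the fields listed above (note that depth also determines the tile-size constraint enforced in CAREamist.predict):

from careamics.config.architectures import UNetModel

unet = UNetModel(
    architecture="UNet",
    conv_dims=2,
    depth=3,                  # tiles must then be divisible by 2**3 = 8
    num_channels_init=64,
    n2v2=False,
    independent_channels=True,
)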
careamics/config/architectures/vae_model.py CHANGED
@@ -17,6 +17,7 @@ class VAEModel(ArchitectureModel):
    )
 
    architecture: Literal["VAE"]
+    """Name of the architecture."""
 
    def set_3D(self, is_3D: bool) -> None:
        """
careamics/config/callback_model.py CHANGED
@@ -13,69 +13,111 @@ from pydantic import (
 )
 
 
 class CheckpointModel(BaseModel):
-    """Checkpoint saving callback Pydantic model."""
+    """Checkpoint saving callback Pydantic model.
+
+    The parameters correspond to those of
+    `pytorch_lightning.callbacks.ModelCheckpoint`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
+    """
 
    model_config = ConfigDict(
        validate_assignment=True,
    )
 
    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    """Quantity to monitor."""
+
    verbose: bool = Field(default=False, validate_default=True)
+    """Verbosity mode."""
+
    save_weights_only: bool = Field(default=False, validate_default=True)
+    """When `True`, only the model's weights will be saved (model.save_weights)."""
+
+    save_last: Optional[Literal[True, False, "link"]] = Field(
+        default=True, validate_default=True
+    )
+    """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
+
+    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """If `save_top_k == k`, the best k models according to the quantity monitored
+    will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`,
+    all models are saved."""
+
    mode: Literal["min", "max"] = Field(default="min", validate_default=True)
+    """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
+    save file is made based on either the maximization or the minimization of the
+    monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this
+    should be 'min', etc.
+    """
+
    auto_insert_metric_name: bool = Field(default=False, validate_default=True)
+    """When `True`, the checkpoints filenames will contain the metric name."""
+
    every_n_train_steps: Optional[int] = Field(
        default=None, ge=1, le=10, validate_default=True
    )
+    """Number of training steps between checkpoints."""
+
    train_time_interval: Optional[timedelta] = Field(
        default=None, validate_default=True
    )
+    """Checkpoints are monitored at the specified time interval."""
+
    every_n_epochs: Optional[int] = Field(
        default=None, ge=1, le=10, validate_default=True
    )
-    save_last: Optional[Literal[True, False, "link"]] = Field(
-        default=True, validate_default=True
-    )
-    save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """Number of epochs between checkpoints."""
 
 
 class EarlyStoppingModel(BaseModel):
-    """Early stopping callback Pydantic model."""
+    """Early stopping callback Pydantic model.
+
+    The parameters correspond to those of
+    `pytorch_lightning.callbacks.EarlyStopping`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
+    """
 
    model_config = ConfigDict(
        validate_assignment=True,
    )
 
    monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
+    """Quantity to monitor."""
+
+    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+    """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
+    absolute change of less than or equal to min_delta will count as no improvement."""
+
    patience: int = Field(default=3, ge=1, le=10, validate_default=True)
+    """Number of checks with no improvement after which training will be stopped."""
+
+    verbose: bool = Field(default=False, validate_default=True)
+    """Verbosity mode."""
+
    mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
-    min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
+    """One of {min, max, auto}."""
+
    check_finite: bool = Field(default=True, validate_default=True)
-    stop_on_nan: bool = Field(default=True, validate_default=True)
-    verbose: bool = Field(default=False, validate_default=True)
-    restore_best_weights: bool = Field(default=True, validate_default=True)
-    auto_lr_find: bool = Field(default=False, validate_default=True)
-    auto_lr_find_patience: int = Field(default=3, ge=1, le=10, validate_default=True)
-    auto_lr_find_mode: Literal["min", "max", "auto"] = Field(
-        default="min", validate_default=True
-    )
-    auto_lr_find_direction: Literal["forward", "backward"] = Field(
-        default="backward", validate_default=True
-    )
-    auto_lr_find_max_lr: float = Field(
-        default=10.0, ge=0.0, le=1e6, validate_default=True
-    )
-    auto_lr_find_min_lr: float = Field(
-        default=1e-8, ge=0.0, le=1e6, validate_default=True
-    )
-    auto_lr_find_num_training: int = Field(
-        default=100, ge=1, le=1e6, validate_default=True
-    )
-    auto_lr_find_divergence_threshold: float = Field(
-        default=5.0, ge=0.0, le=1e6, validate_default=True
-    )
-    auto_lr_find_accumulate_grad_batches: int = Field(
-        default=1, ge=1, le=1e6, validate_default=True
-    )
-    auto_lr_find_stop_divergence: bool = Field(default=True, validate_default=True)
-    auto_lr_find_step_scale: float = Field(default=0.1, ge=0.0, le=10)
+    """When `True`, stops training when the monitored quantity becomes `NaN` or
+    `inf`."""
+
+    stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
+    """Stop training immediately once the monitored quantity reaches this threshold."""
+
+    divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
+    """Stop training as soon as the monitored quantity becomes worse than this
+    threshold."""
+
+    check_on_train_epoch_end: Optional[bool] = Field(
+        default=False, validate_default=True
+    )
+    """Whether to run early stopping at the end of the training epoch. If this is
+    `False`, then the check runs at the end of the validation."""
+
+    log_rank_zero_only: bool = Field(default=False, validate_default=True)
+    """When set `True`, logs the status of the early stopping callback only for rank 0
+    process."""
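The callback models now mirror the Lightning ModelCheckpoint and EarlyStopping options instead of the removed auto_lr_find_* fields. A hedged configuration sketch using only fields shown above; the module path follows the file name, and whether these models are re-exported elsewhere is not shown in this diff:

from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel

checkpoint_cfg = CheckpointModel(
    save_top_k=5,        # bounded to 1..10 by the Field constraints above
    save_last=True,
    every_n_epochs=1,
)

early_stopping_cfg = EarlyStoppingModel(
    patience=5,
    min_delta=0.001,
    mode="min",
)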
careamics/config/configuration_factory.py CHANGED
@@ -1,12 +1,11 @@
 """Convenience functions to create configurations for training and inference."""
 
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional
 
 from .algorithm_model import AlgorithmConfig
 from .architectures import UNetModel
 from .configuration_model import Configuration
 from .data_model import DataConfig
-from .inference_model import InferenceConfig
 from .support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -574,80 +573,3 @@ def create_n2v_configuration(
    )
 
    return configuration
-
-
-def create_inference_configuration(
-    configuration: Configuration,
-    tile_size: Optional[Tuple[int, ...]] = None,
-    tile_overlap: Optional[Tuple[int, ...]] = None,
-    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
-    axes: Optional[str] = None,
-    tta_transforms: bool = True,
-    batch_size: Optional[int] = 1,
-) -> InferenceConfig:
-    """
-    Create a configuration for inference with N2V.
-
-    If not provided, `data_type` and `axes` are taken from the training
-    configuration.
-
-    Parameters
-    ----------
-    configuration : Configuration
-        Global configuration.
-    tile_size : Tuple[int, ...], optional
-        Size of the tiles.
-    tile_overlap : Tuple[int, ...], optional
-        Overlap of the tiles.
-    data_type : str, optional
-        Type of the data, by default "tiff".
-    axes : str, optional
-        Axes of the data, by default "YX".
-    tta_transforms : bool, optional
-        Whether to apply test-time augmentations, by default True.
-    batch_size : int, optional
-        Batch size, by default 1.
-
-    Returns
-    -------
-    InferenceConfiguration
-        Configuration used to configure CAREamicsPredictData.
-    """
-    if (
-        configuration.data_config.image_means is None
-        or configuration.data_config.image_stds is None
-    ):
-        raise ValueError("Mean and std must be provided in the configuration.")
-
-    # tile size for UNets
-    if tile_size is not None:
-        model = configuration.algorithm_config.model
-
-        if model.architecture == SupportedArchitecture.UNET.value:
-            # tile size must be equal to k*2^n, where n is the number of pooling layers
-            # (equal to the depth) and k is an integer
-            depth = model.depth
-            tile_increment = 2**depth
-
-            for i, t in enumerate(tile_size):
-                if t % tile_increment != 0:
-                    raise ValueError(
-                        f"Tile size must be divisible by {tile_increment} along all "
-                        f"axes (got {t} for axis {i}). If your image size is smaller "
-                        f"along one axis (e.g. Z), consider padding the image."
-                    )
-
-        # tile overlaps must be specified
-        if tile_overlap is None:
-            raise ValueError("Tile overlap must be specified.")
-
-    return InferenceConfig(
-        data_type=data_type or configuration.data_config.data_type,
-        tile_size=tile_size,
-        tile_overlap=tile_overlap,
-        axes=axes or configuration.data_config.axes,
-        image_means=configuration.data_config.image_means,
-        image_stds=configuration.data_config.image_stds,
-        tta_transforms=tta_transforms,
-        batch_size=batch_size,
-    )
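With create_inference_configuration removed here (and its re-export dropped from careamics/config/__init__.py), there is no separate inference-configuration step: the same options are passed directly to CAREamist.predict, which now performs the mean/std, tile-size, and tile-overlap checks itself. A hedged before/after sketch:

# 0.1.0rc7
# from careamics.config import create_inference_configuration
# inference_config = create_inference_configuration(
#     configuration=config, tile_size=(256, 256), tile_overlap=(48, 48)
# )
# ... then build a CAREamicsPredictData from inference_config ...

# 0.1.0rc8 (assumed equivalent): no intermediate InferenceConfig; the options go
# straight into CAREamist.predict (see the predict() sketch earlier in this diff).
prediction = careamist.predict(
    "data/test",
    data_type="tiff",
    tile_size=(256, 256),
    tile_overlap=(48, 48),
)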