careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/__init__.py CHANGED
@@ -7,7 +7,22 @@ try:
 except PackageNotFoundError:
     __version__ = "uninstalled"

-__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]
+__all__ = [
+    "CAREamist",
+    "Configuration",
+    "algorithm_factory",
+    "configuration_factory",
+    "data_factory",
+    "load_configuration",
+    "save_configuration",
+]

 from .careamist import CAREamist
-from .config import Configuration, load_configuration, save_configuration
+from .config import (
+    Configuration,
+    algorithm_factory,
+    configuration_factory,
+    data_factory,
+    load_configuration,
+    save_configuration,
+)
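With the expanded `__all__`, the new factory helpers are importable straight from the package root. A minimal sketch of the 0.0.6 import surface (the factories' signatures are not shown in this diff, so only the imports are illustrated):

    # names re-exported at the package root as of 0.0.6
    from careamics import (
        CAREamist,
        Configuration,
        algorithm_factory,
        configuration_factory,
        data_factory,
        load_configuration,
        save_configuration,
    )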
careamics/careamist.py CHANGED
@@ -11,16 +11,17 @@ from pytorch_lightning.callbacks import (
     EarlyStopping,
     ModelCheckpoint,
 )
-from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger

-from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
+from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
     SupportedData,
     SupportedLogger,
 )
-from careamics.dataset.dataset_utils import reshape_array
+from careamics.dataset.dataset_utils import list_files, reshape_array
+from careamics.file_io import WriteFunc, get_write_func
 from careamics.lightning import (
     FCNModule,
     HyperParametersCallback,
@@ -32,10 +33,11 @@ from careamics.lightning import (
 from careamics.model_io import export_to_bmz, load_pretrained
 from careamics.prediction_utils import convert_outputs
 from careamics.utils import check_path_exists, get_logger
+from careamics.utils.lightning_utils import read_csv_logger

 logger = get_logger(__name__)

-LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
+LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]


 class CAREamist:
@@ -135,7 +137,7 @@ class CAREamist:
             self.cfg = source

             # instantiate model
-            if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+            if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
                 self.model = FCNModule(
                     algorithm_config=self.cfg.algorithm_config,
                 )
@@ -144,6 +146,7 @@ class CAREamist:

         # path to configuration file or model
         else:
+            # TODO: update this check so models can be downloaded directly from BMZ
             source = check_path_exists(source)

             # configuration file
@@ -154,7 +157,8 @@ class CAREamist:
                 self.cfg = load_configuration(source)

                 # instantiate model
-                if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                # TODO call model factory here
+                if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
                     self.model = FCNModule(
                         algorithm_config=self.cfg.algorithm_config,
                     )  # type: ignore
@@ -169,18 +173,29 @@ class CAREamist:
         self._define_callbacks(callbacks)

         # instantiate logger
+        csv_logger = CSVLogger(
+            name=self.cfg.experiment_name,
+            save_dir=self.work_dir / "csv_logs",
+        )
+
         if self.cfg.training_config.has_logger():
             if self.cfg.training_config.logger == SupportedLogger.WANDB:
-                self.experiment_logger: LOGGER_TYPES = WandbLogger(
-                    name=self.cfg.experiment_name,
-                    save_dir=self.work_dir / Path("logs"),
-                )
+                experiment_logger: LOGGER_TYPES = [
+                    WandbLogger(
+                        name=self.cfg.experiment_name,
+                        save_dir=self.work_dir / Path("wandb_logs"),
+                    ),
+                    csv_logger,
+                ]
             elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
-                self.experiment_logger = TensorBoardLogger(
-                    save_dir=self.work_dir / Path("logs"),
-                )
+                experiment_logger = [
+                    TensorBoardLogger(
+                        save_dir=self.work_dir / Path("tb_logs"),
+                    ),
+                    csv_logger,
+                ]
         else:
-            self.experiment_logger = None
+            experiment_logger = [csv_logger]

         # instantiate trainer
         self.trainer = Trainer(
@@ -194,7 +209,7 @@ class CAREamist:
             gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
             callbacks=self.callbacks,
             default_root_dir=self.work_dir,
-            logger=self.experiment_logger,
+            logger=experiment_logger,
         )

         # place holder for the datamodules
@@ -519,7 +534,7 @@ class CAREamist:
         *,
         batch_size: int = 1,
         tile_size: Optional[tuple[int, ...]] = None,
-        tile_overlap: tuple[int, ...] = (48, 48),
+        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["tiff", "custom"]] = None,
         tta_transforms: bool = False,
@@ -535,7 +550,7 @@ class CAREamist:
         *,
         batch_size: int = 1,
         tile_size: Optional[tuple[int, ...]] = None,
-        tile_overlap: tuple[int, ...] = (48, 48),
+        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
         axes: Optional[str] = None,
         data_type: Optional[Literal["array"]] = None,
         tta_transforms: bool = False,
@@ -546,7 +561,7 @@ class CAREamist:
         self,
         source: Union[PredictDataModule, Path, str, NDArray],
         *,
-        batch_size: Optional[int] = None,
+        batch_size: int = 1,
         tile_size: Optional[tuple[int, ...]] = None,
         tile_overlap: Optional[tuple[int, ...]] = (48, 48),
         axes: Optional[str] = None,
@@ -567,7 +582,7 @@ class CAREamist:
         configuration parameters will be used, with the `patch_size` instead of
         `tile_size`.

-        Test-time augmentation (TTA) can be switched off using the `tta_transforms`
+        Test-time augmentation (TTA) can be switched on using the `tta_transforms`
         parameter. The TTA augmentation applies all possible flip and 90 degrees
         rotations to the prediction input and averages the predictions. TTA augmentation
         should not be used if you did not train with these augmentations.
@@ -580,7 +595,7 @@ class CAREamist:

         Parameters
         ----------
-        source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
+        source : PredictDataModule, pathlib.Path, str or numpy.ndarray
             Data to predict on.
         batch_size : int, default=1
             Batch size for prediction.
@@ -668,15 +683,195 @@ class CAREamist:
             )
         return convert_outputs(predictions, self.pred_datamodule.tiled)

+    def predict_to_disk(
+        self,
+        source: Union[PredictDataModule, Path, str],
+        *,
+        batch_size: int = 1,
+        tile_size: Optional[tuple[int, ...]] = None,
+        tile_overlap: Optional[tuple[int, ...]] = (48, 48),
+        axes: Optional[str] = None,
+        data_type: Optional[Literal["tiff", "custom"]] = None,
+        tta_transforms: bool = False,
+        dataloader_params: Optional[dict] = None,
+        read_source_func: Optional[Callable] = None,
+        extension_filter: str = "",
+        write_type: Literal["tiff", "custom"] = "tiff",
+        write_extension: Optional[str] = None,
+        write_func: Optional[WriteFunc] = None,
+        write_func_kwargs: Optional[dict[str, Any]] = None,
+        prediction_dir: Union[Path, str] = "predictions",
+        **kwargs,
+    ) -> None:
+        """
+        Make predictions on the provided data and save outputs to files.
+
+        The predictions will be saved in a new directory 'predictions' within the set
+        working directory. The directory stucture within the 'predictions' directory
+        will match that of the source directory.
+
+        The `source` must be from files and not arrays. The file names of the
+        predictions will match those of the source. If there is more than one sample
+        within a file, the samples will be saved to seperate files. The file names of
+        samples will have the name of the corresponding source file but with the sample
+        index appended. E.g. If the the source file name is 'images.tiff' then the first
+        sample's prediction will be saved with the file name "image_0.tiff".
+        Input can be a PredictDataModule instance, a path to a data file, or a numpy
+        array.
+
+        If `data_type`, `axes` and `tile_size` are not provided, the training
+        configuration parameters will be used, with the `patch_size` instead of
+        `tile_size`.
+
+        Test-time augmentation (TTA) can be switched on using the `tta_transforms`
+        parameter. The TTA augmentation applies all possible flip and 90 degrees
+        rotations to the prediction input and averages the predictions. TTA augmentation
+        should not be used if you did not train with these augmentations.
+
+        Note that if you are using a UNet model and tiling, the tile size must be
+        divisible in every dimension by 2**d, where d is the depth of the model. This
+        avoids artefacts arising from the broken shift invariance induced by the
+        pooling layers of the UNet. If your image has less dimensions, as it may
+        happen in the Z dimension, consider padding your image.
+
+        Parameters
+        ----------
+        source : PredictDataModule or pathlib.Path, str
+            Data to predict on.
+        batch_size : int, default=1
+            Batch size for prediction.
+        tile_size : tuple of int, optional
+            Size of the tiles to use for prediction.
+        tile_overlap : tuple of int, default=(48, 48)
+            Overlap between tiles.
+        axes : str, optional
+            Axes of the input data, by default None.
+        data_type : {"array", "tiff", "custom"}, optional
+            Type of the input data.
+        tta_transforms : bool, default=True
+            Whether to apply test-time augmentation.
+        dataloader_params : dict, optional
+            Parameters to pass to the dataloader.
+        read_source_func : Callable, optional
+            Function to read the source data.
+        extension_filter : str, default=""
+            Filter for the file extension.
+        write_type : {"tiff", "custom"}, default="tiff"
+            The data type to save as, includes custom.
+        write_extension : str, optional
+            If a known `write_type` is selected this argument is ignored. For a custom
+            `write_type` an extension to save the data with must be passed.
+        write_func : WriteFunc, optional
+            If a known `write_type` is selected this argument is ignored. For a custom
+            `write_type` a function to save the data must be passed. See notes below.
+        write_func_kwargs : dict of {str: any}, optional
+            Additional keyword arguments to be passed to the save function.
+        prediction_dir : Path | str, default="predictions"
+            The path to save the prediction results to. If `prediction_dir` is not
+            absolute, the directory will be assumed to be relative to the pre-set
+            `work_dir`. If the directory does not exist it will be created.
+        **kwargs : Any
+            Unused.
+
+        Raises
+        ------
+        ValueError
+            If `write_type` is custom and `write_extension` is None.
+        ValueError
+            If `write_type` is custom and `write_fun is None.
+        ValueError
+            If `source` is not `str`, `Path` or `PredictDataModule`
+        """
+        if write_func_kwargs is None:
+            write_func_kwargs = {}
+
+        if Path(prediction_dir).is_absolute():
+            write_dir = Path(prediction_dir)
+        else:
+            write_dir = self.work_dir / prediction_dir
+        write_dir.mkdir(exist_ok=True, parents=True)
+
+        # guards for custom types
+        if write_type == SupportedData.CUSTOM:
+            if write_extension is None:
+                raise ValueError(
+                    "A `write_extension` must be provided for custom write types."
+                )
+            if write_func is None:
+                raise ValueError(
+                    "A `write_func` must be provided for custom write types."
+                )
+        else:
+            write_func = get_write_func(write_type)
+            write_extension = SupportedData.get_extension(write_type)
+
+        # extract file names
+        source_path: Union[Path, str, NDArray]
+        source_data_type: Literal["array", "tiff", "custom"]
+        if isinstance(source, PredictDataModule):
+            source_path = source.pred_data
+            source_data_type = source.data_type
+            extension_filter = source.extension_filter
+        elif isinstance(source, (str, Path)):
+            source_path = source
+            source_data_type = data_type or self.cfg.data_config.data_type
+            extension_filter = SupportedData.get_extension_pattern(
+                SupportedData(source_data_type)
+            )
+        else:
+            raise ValueError(f"Unsupported source type: '{type(source)}'.")
+
+        if source_data_type == "array":
+            raise ValueError(
+                "Predicting to disk is not supported for input type 'array'."
+            )
+        assert isinstance(source_path, (Path, str))  # because data_type != "array"
+        source_path = Path(source_path)
+
+        file_paths = list_files(source_path, source_data_type, extension_filter)
+
+        # predict and write each file in turn
+        for file_path in file_paths:
+            # source_path is relative to original source path...
+            # should mirror original directory structure
+            prediction = self.predict(
+                source=file_path,
+                batch_size=batch_size,
+                tile_size=tile_size,
+                tile_overlap=tile_overlap,
+                axes=axes,
+                data_type=data_type,
+                tta_transforms=tta_transforms,
+                dataloader_params=dataloader_params,
+                read_source_func=read_source_func,
+                extension_filter=extension_filter,
+                **kwargs,
+            )
+            # TODO: cast to float16?
+            write_data = np.concatenate(prediction)
+
+            # create directory structure and write path
+            if not source_path.is_file():
+                file_write_dir = write_dir / file_path.parent.relative_to(source_path)
+            else:
+                file_write_dir = write_dir
+            file_write_dir.mkdir(parents=True, exist_ok=True)
+            write_path = (file_write_dir / file_path.name).with_suffix(write_extension)
+
+            # write data
+            write_func(file_path=write_path, img=write_data)
+
     def export_to_bmz(
         self,
         path_to_archive: Union[Path, str],
         friendly_model_name: str,
         input_array: NDArray,
         authors: list[dict],
-        general_description: str = "",
+        general_description: str,
+        data_description: str,
+        covers: Optional[list[Union[Path, str]]] = None,
         channel_names: Optional[list[str]] = None,
-        data_description: Optional[str] = None,
+        model_version: str = "0.1.0",
     ) -> None:
         """Export the model to the BioImage Model Zoo format.

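The new `predict_to_disk` method writes predictions straight to files instead of returning arrays. A minimal usage sketch, assuming a trained checkpoint `model.ckpt` and a folder of TIFFs `data/` (both placeholder paths):

    from careamics import CAREamist

    careamist = CAREamist(source="model.ckpt", work_dir="results")  # placeholder paths
    careamist.predict_to_disk(
        source="data/",        # must point to files, not an in-memory array
        tile_size=(256, 256),  # each dimension divisible by 2**depth for UNet models
        tile_overlap=(48, 48),
        write_type="tiff",
        prediction_dir="predictions",  # relative paths land under work_dir
    )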
@@ -706,11 +901,15 @@ class CAREamist:
         authors : list of dict
             List of authors of the model.
         general_description : str
-            General description of the model, used in the metadata of the BMZ archive.
-        channel_names : list of str, optional
-            Channel names, by default None.
-        data_description : str, optional
-            Description of the data, by default None.
+            General description of the model used in the BMZ metadata.
+        data_description : str
+            Description of the data the model was trained on.
+        covers : list of pathlib.Path or str, default=None
+            Paths to the cover images.
+        channel_names : list of str, default=None
+            Channel names.
+        model_version : str, default="0.1.0"
+            Version of the model.
         """
         # TODO: add in docs that it is expected that input_array dimensions match
         # those in data_config
@@ -729,9 +928,21 @@ class CAREamist:
             path_to_archive=path_to_archive,
             model_name=friendly_model_name,
             general_description=general_description,
+            data_description=data_description,
             authors=authors,
             input_array=input_array,
             output_array=output,
+            covers=covers,
             channel_names=channel_names,
-            data_description=data_description,
+            model_version=model_version,
         )
+
+    def get_losses(self) -> dict[str, list]:
+        """Return data that can be used to plot train and validation loss curves.
+
+        Returns
+        -------
+        dict of str: list
+            Dictionary containing the losses for each epoch.
+        """
+        return read_csv_logger(self.cfg.experiment_name, self.work_dir / "csv_logs")
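Because a `CSVLogger` is now always attached, the loss history survives training and `get_losses` can retrieve it afterwards. A plotting sketch, assuming the returned dictionary exposes per-epoch `train_loss` and `val_loss` lists (the exact keys come from `read_csv_logger`, which is not shown in this diff):

    import matplotlib.pyplot as plt

    losses = careamist.get_losses()  # reads from work_dir / "csv_logs"
    # "train_loss" and "val_loss" are assumed keys; inspect losses.keys() to confirm
    plt.plot(losses["train_loss"], label="train")
    plt.plot(losses["val_loss"], label="validation")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.show()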
careamics/cli/conf.py CHANGED
@@ -3,12 +3,11 @@
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Tuple
+from typing import Annotated, Optional

 import click
 import typer
 import yaml
-from typing_extensions import Annotated

 from ..config import (
     Configuration,
@@ -17,6 +16,7 @@ from ..config import (
     create_n2v_configuration,
     save_configuration,
 )
+from .utils import handle_2D_3D_callback

 WORK_DIR = Path.cwd()

@@ -92,26 +92,6 @@ def conf_options( # numpydoc ignore=PR01
     ctx.obj = ConfOptions(dir, name, force, print)


-def patch_size_callback(value: Tuple[int, int, int]) -> Tuple[int, ...]:
-    """
-    Callback for --patch-size option.
-
-    Parameters
-    ----------
-    value : (int, int, int)
-        Patch size value.
-
-    Returns
-    -------
-    (int, int, int) | (int, int)
-        If the last element in `value` is -1 the tuple is reduced to the first two
-        values.
-    """
-    if value[2] == -1:
-        return value[:2]
-    return value
-
-

 # TODO: Need to decide how to parse model kwargs
 # - Could be json style string to be loaded as dict e.g. {"depth": 3}
@@ -132,7 +112,7 @@ def care( # numpydoc ignore=PR01
                 "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
             ),
             click_type=click.Tuple([int, int, int]),
-            callback=patch_size_callback,
+            callback=handle_2D_3D_callback,
         ),
     ],
     batch_size: Annotated[int, typer.Option(help="Batch size.")],
@@ -154,8 +134,12 @@ def care( # numpydoc ignore=PR01
             help="Loss function to use.",
         ),
     ] = "mae",
-    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
-    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
+    n_channels_in: Annotated[
+        Optional[int], typer.Option(help="Number of channels in")
+    ] = None,
+    n_channels_out: Annotated[
+        Optional[int], typer.Option(help="Number of channels out")
+    ] = None,
     logger: Annotated[
         click.Choice,
         typer.Option(
@@ -215,7 +199,7 @@ def n2n( # numpydoc ignore=PR01
                 "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
             ),
             click_type=click.Tuple([int, int, int]),
-            callback=patch_size_callback,
+            callback=handle_2D_3D_callback,
         ),
     ],
     batch_size: Annotated[int, typer.Option(help="Batch size.")],
@@ -237,8 +221,12 @@ def n2n( # numpydoc ignore=PR01
             help="Loss function to use.",
         ),
     ] = "mae",
-    n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
-    n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
+    n_channels_in: Annotated[
+        Optional[int], typer.Option(help="Number of channels in")
+    ] = None,
+    n_channels_out: Annotated[
+        Optional[int], typer.Option(help="Number of channels out")
+    ] = None,
     logger: Annotated[
         click.Choice,
         typer.Option(
@@ -295,7 +283,7 @@ def n2v( # numpydoc ignore=PR01
                 "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
             ),
             click_type=click.Tuple([int, int, int]),
-            callback=patch_size_callback,
+            callback=handle_2D_3D_callback,
         ),
     ],
     batch_size: Annotated[int, typer.Option(help="Batch size.")],
@@ -312,8 +300,8 @@ def n2v( # numpydoc ignore=PR01
     ] = True,
     use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
     n_channels: Annotated[
-        int, typer.Option(help="Number of channels (in and out)")
-    ] = 1,
+        Optional[int], typer.Option(help="Number of channels (in and out)")
+    ] = None,
     roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
     masked_pixel_percentage: Annotated[
         float, typer.Option(help="Percentage of pixels masked in each patch.")
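The same configurations can be built from Python through the factories imported at the top of this file. A sketch using `create_n2v_configuration`; the argument names below mirror the CLI options, but the exact 0.0.6 signature is not shown in this diff:

    from careamics.config import create_n2v_configuration

    config = create_n2v_configuration(  # argument names assumed from the CLI options
        experiment_name="n2v_demo",
        data_type="tiff",
        axes="YX",
        patch_size=(64, 64),  # 2D: the CLI's trailing -1 is already stripped
        batch_size=8,
        num_epochs=30,
    )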
careamics/cli/main.py CHANGED
@@ -7,23 +7,21 @@ its implementation is contained in the conf.py file.
 """

 from pathlib import Path
-from typing import Optional
+from typing import Annotated, Optional

+import click
 import typer
-from typing_extensions import Annotated

 from ..careamist import CAREamist
 from . import conf
+from .utils import handle_2D_3D_callback

 app = typer.Typer(
     help="Run CAREamics algorithms from the command line, including Noise2Void "
-    "and its many variants and cousins"
-)
-app.add_typer(
-    conf.app,
-    name="conf",
-    # callback=conf.conf_options
+    "and its many variants and cousins",
+    pretty_exceptions_show_locals=False,
 )
+app.add_typer(conf.app, name="conf")


 @app.command()
@@ -102,7 +100,7 @@ def train( # numpydoc ignore=PR01
         typer.Option(
             "--work-dir",
             "-wd",
-            help=("Path to working directory in which to save checkpoints and " "logs"),
+            help=("Path to working directory in which to save checkpoints and logs"),
             exists=True,
             file_okay=False,
             dir_okay=True,
@@ -123,10 +121,112 @@


 @app.command()
-def predict(): # numpydoc ignore=PR01
+def predict( # numpydoc ignore=PR01
+    model: Annotated[
+        Path,
+        typer.Argument(
+            help="Path to a configuration file or a trained model.",
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+        ),
+    ],
+    source: Annotated[
+        Path,
+        typer.Argument(
+            help="Path to the training data. Can be a directory or single file.",
+            exists=True,
+            file_okay=True,
+            dir_okay=True,
+        ),
+    ],
+    batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
+    tile_size: Annotated[
+        Optional[click.Tuple],
+        typer.Option(
+            help=(
+                "Size of the tiles to use for prediction, (if the data "
+                "is not 3D pass the last value as -1 e.g. --tile_size 64 64 -1)."
+            ),
+            click_type=click.Tuple([int, int, int]),
+            callback=handle_2D_3D_callback,
+        ),
+    ] = None,
+    tile_overlap: Annotated[
+        click.Tuple,
+        typer.Option(
+            help=(
+                "Overlap between tiles, (if the data is not 3D pass the last value as "
+                "-1 e.g. --tile_overlap 64 64 -1)."
+            ),
+            click_type=click.Tuple([int, int, int]),
+            callback=handle_2D_3D_callback,
+        ),
+    ] = (48, 48, -1),
+    axes: Annotated[
+        Optional[str],
+        typer.Option(
+            help="Axes of the input data. If unused the data is assumed to have the "
+            "same axes as the original training data."
+        ),
+    ] = None,
+    data_type: Annotated[
+        click.Choice,
+        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the input data."),
+    ] = "tiff",
+    tta_transforms: Annotated[
+        bool,
+        typer.Option(
+            "--tta-transforms/--no-tta-transforms",
+            "-t/-T",
+            help="Whether to apply test-time augmentation.",
+        ),
+    ] = False,
+    write_type: Annotated[
+        click.Choice,
+        typer.Option(
+            click_type=click.Choice(["tiff"]), help="Type of the output data."
+        ),
+    ] = "tiff",
+    # TODO: could make dataloader_params as json, necessary?
+    work_dir: Annotated[
+        Optional[Path],
+        typer.Option(
+            "--work-dir",
+            "-wd",
+            help=("Path to working directory."),
+            exists=True,
+            file_okay=False,
+            dir_okay=True,
+        ),
+    ] = None,
+    prediction_dir: Annotated[
+        Path,
+        typer.Option(
+            "--prediction-dir",
+            "-pd",
+            help=(
+                "Directory to save predictions to. If not an abosulte path it will be "
+                "relative to the set working directory."
+            ),
+            file_okay=False,
+            dir_okay=True,
+        ),
+    ] = Path("predictions"),
+):
     """Create and save predictions from CAREamics models."""
-    # TODO: Need a save predict to workdir function
-    raise NotImplementedError
+    engine = CAREamist(source=model, work_dir=work_dir)
+    engine.predict_to_disk(
+        source=source,
+        batch_size=batch_size,
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+        axes=axes,
+        data_type=data_type,
+        tta_transforms=tta_transforms,
+        write_type=write_type,
+        prediction_dir=prediction_dir,
+    )


 def run():
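With `predict` now wired to `predict_to_disk`, the command can be exercised programmatically through typer's test runner. A sketch with placeholder paths (typer exposes the `tile_size` parameter as `--tile-size` by default, despite the `--tile_size` spelling in the help text):

    from typer.testing import CliRunner

    from careamics.cli.main import app

    runner = CliRunner()
    result = runner.invoke(  # "model.ckpt" and "data/" are placeholder paths
        app,
        ["predict", "model.ckpt", "data/", "--tile-size", "64", "64", "-1"],
    )
    print(result.output)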