careamics-0.0.4.2-py3-none-any.whl → careamics-0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's advisory for this release for more details.

Files changed (43):
  1. careamics/careamist.py +235 -25
  2. careamics/cli/conf.py +19 -30
  3. careamics/cli/main.py +111 -10
  4. careamics/cli/utils.py +29 -0
  5. careamics/config/__init__.py +2 -0
  6. careamics/config/architectures/lvae_model.py +104 -21
  7. careamics/config/configuration_factory.py +49 -45
  8. careamics/config/configuration_model.py +2 -2
  9. careamics/config/likelihood_model.py +7 -6
  10. careamics/config/loss_model.py +56 -0
  11. careamics/config/nm_model.py +24 -24
  12. careamics/config/vae_algorithm_model.py +14 -13
  13. careamics/dataset/dataset_utils/running_stats.py +22 -23
  14. careamics/lightning/lightning_module.py +58 -27
  15. careamics/lightning/train_data_module.py +15 -1
  16. careamics/losses/loss_factory.py +1 -85
  17. careamics/losses/lvae/losses.py +223 -164
  18. careamics/lvae_training/calibration.py +184 -0
  19. careamics/lvae_training/dataset/config.py +2 -2
  20. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  21. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  22. careamics/lvae_training/dataset/types.py +15 -26
  23. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  24. careamics/lvae_training/eval_utils.py +125 -213
  25. careamics/model_io/bioimage/_readme_factory.py +25 -33
  26. careamics/model_io/bioimage/cover_factory.py +171 -0
  27. careamics/model_io/bioimage/model_description.py +39 -17
  28. careamics/model_io/bmz_io.py +36 -25
  29. careamics/models/layers.py +6 -4
  30. careamics/models/lvae/layers.py +348 -975
  31. careamics/models/lvae/likelihoods.py +10 -8
  32. careamics/models/lvae/lvae.py +214 -272
  33. careamics/models/lvae/noise_models.py +179 -112
  34. careamics/models/lvae/stochastic.py +393 -0
  35. careamics/models/lvae/utils.py +82 -73
  36. careamics/utils/lightning_utils.py +57 -0
  37. careamics/utils/serializers.py +2 -0
  38. careamics/utils/torch_utils.py +1 -1
  39. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
  40. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
  41. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
  42. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
  43. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -11,7 +11,7 @@ from pytorch_lightning.callbacks import (
11
11
  EarlyStopping,
12
12
  ModelCheckpoint,
13
13
  )
14
- from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
14
+ from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
15
15
 
16
16
  from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
17
17
  from careamics.config.support import (
@@ -20,7 +20,8 @@ from careamics.config.support import (
20
20
  SupportedData,
21
21
  SupportedLogger,
22
22
  )
23
- from careamics.dataset.dataset_utils import reshape_array
23
+ from careamics.dataset.dataset_utils import list_files, reshape_array
24
+ from careamics.file_io import WriteFunc, get_write_func
24
25
  from careamics.lightning import (
25
26
  FCNModule,
26
27
  HyperParametersCallback,
@@ -32,10 +33,11 @@ from careamics.lightning import (
32
33
  from careamics.model_io import export_to_bmz, load_pretrained
33
34
  from careamics.prediction_utils import convert_outputs
34
35
  from careamics.utils import check_path_exists, get_logger
36
+ from careamics.utils.lightning_utils import read_csv_logger
35
37
 
36
38
  logger = get_logger(__name__)
37
39
 
38
- LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
40
+ LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
39
41
 
40
42
 
41
43
  class CAREamist:
@@ -144,6 +146,7 @@ class CAREamist:
144
146
 
145
147
  # path to configuration file or model
146
148
  else:
149
+ # TODO: update this check so models can be downloaded directly from BMZ
147
150
  source = check_path_exists(source)
148
151
 
149
152
  # configuration file
@@ -169,18 +172,29 @@ class CAREamist:
169
172
  self._define_callbacks(callbacks)
170
173
 
171
174
  # instantiate logger
175
+ csv_logger = CSVLogger(
176
+ name=self.cfg.experiment_name,
177
+ save_dir=self.work_dir / "csv_logs",
178
+ )
179
+
172
180
  if self.cfg.training_config.has_logger():
173
181
  if self.cfg.training_config.logger == SupportedLogger.WANDB:
174
- self.experiment_logger: LOGGER_TYPES = WandbLogger(
175
- name=self.cfg.experiment_name,
176
- save_dir=self.work_dir / Path("logs"),
177
- )
182
+ experiment_logger: LOGGER_TYPES = [
183
+ WandbLogger(
184
+ name=self.cfg.experiment_name,
185
+ save_dir=self.work_dir / Path("wandb_logs"),
186
+ ),
187
+ csv_logger,
188
+ ]
178
189
  elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
179
- self.experiment_logger = TensorBoardLogger(
180
- save_dir=self.work_dir / Path("logs"),
181
- )
190
+ experiment_logger = [
191
+ TensorBoardLogger(
192
+ save_dir=self.work_dir / Path("tb_logs"),
193
+ ),
194
+ csv_logger,
195
+ ]
182
196
  else:
183
- self.experiment_logger = None
197
+ experiment_logger = [csv_logger]
184
198
 
185
199
  # instantiate trainer
186
200
  self.trainer = Trainer(
@@ -194,7 +208,7 @@ class CAREamist:
194
208
  gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
195
209
  callbacks=self.callbacks,
196
210
  default_root_dir=self.work_dir,
197
- logger=self.experiment_logger,
211
+ logger=experiment_logger,
198
212
  )
199
213
 
200
214
  # place holder for the datamodules
@@ -519,7 +533,7 @@ class CAREamist:
519
533
  *,
520
534
  batch_size: int = 1,
521
535
  tile_size: Optional[tuple[int, ...]] = None,
522
- tile_overlap: tuple[int, ...] = (48, 48),
536
+ tile_overlap: Optional[tuple[int, ...]] = (48, 48),
523
537
  axes: Optional[str] = None,
524
538
  data_type: Optional[Literal["tiff", "custom"]] = None,
525
539
  tta_transforms: bool = False,
@@ -535,7 +549,7 @@ class CAREamist:
535
549
  *,
536
550
  batch_size: int = 1,
537
551
  tile_size: Optional[tuple[int, ...]] = None,
538
- tile_overlap: tuple[int, ...] = (48, 48),
552
+ tile_overlap: Optional[tuple[int, ...]] = (48, 48),
539
553
  axes: Optional[str] = None,
540
554
  data_type: Optional[Literal["array"]] = None,
541
555
  tta_transforms: bool = False,
@@ -546,7 +560,7 @@ class CAREamist:
546
560
  self,
547
561
  source: Union[PredictDataModule, Path, str, NDArray],
548
562
  *,
549
- batch_size: Optional[int] = None,
563
+ batch_size: int = 1,
550
564
  tile_size: Optional[tuple[int, ...]] = None,
551
565
  tile_overlap: Optional[tuple[int, ...]] = (48, 48),
552
566
  axes: Optional[str] = None,
@@ -567,7 +581,7 @@ class CAREamist:
567
581
  configuration parameters will be used, with the `patch_size` instead of
568
582
  `tile_size`.
569
583
 
570
- Test-time augmentation (TTA) can be switched off using the `tta_transforms`
584
+ Test-time augmentation (TTA) can be switched on using the `tta_transforms`
571
585
  parameter. The TTA augmentation applies all possible flip and 90 degrees
572
586
  rotations to the prediction input and averages the predictions. TTA augmentation
573
587
  should not be used if you did not train with these augmentations.
@@ -580,7 +594,7 @@ class CAREamist:
580
594
 
581
595
  Parameters
582
596
  ----------
583
- source : CAREamicsPredData, pathlib.Path, str or numpy.ndarray
597
+ source : PredictDataModule, pathlib.Path, str or numpy.ndarray
584
598
  Data to predict on.
585
599
  batch_size : int, default=1
586
600
  Batch size for prediction.
@@ -668,15 +682,195 @@ class CAREamist:
668
682
  )
669
683
  return convert_outputs(predictions, self.pred_datamodule.tiled)
670
684
 
685
+ def predict_to_disk(
686
+ self,
687
+ source: Union[PredictDataModule, Path, str],
688
+ *,
689
+ batch_size: int = 1,
690
+ tile_size: Optional[tuple[int, ...]] = None,
691
+ tile_overlap: Optional[tuple[int, ...]] = (48, 48),
692
+ axes: Optional[str] = None,
693
+ data_type: Optional[Literal["tiff", "custom"]] = None,
694
+ tta_transforms: bool = False,
695
+ dataloader_params: Optional[dict] = None,
696
+ read_source_func: Optional[Callable] = None,
697
+ extension_filter: str = "",
698
+ write_type: Literal["tiff", "custom"] = "tiff",
699
+ write_extension: Optional[str] = None,
700
+ write_func: Optional[WriteFunc] = None,
701
+ write_func_kwargs: Optional[dict[str, Any]] = None,
702
+ prediction_dir: Union[Path, str] = "predictions",
703
+ **kwargs,
704
+ ) -> None:
705
+ """
706
+ Make predictions on the provided data and save outputs to files.
707
+
708
+ The predictions will be saved in a new directory 'predictions' within the set
709
+ working directory. The directory stucture within the 'predictions' directory
710
+ will match that of the source directory.
711
+
712
+ The `source` must be from files and not arrays. The file names of the
713
+ predictions will match those of the source. If there is more than one sample
714
+ within a file, the samples will be saved to seperate files. The file names of
715
+ samples will have the name of the corresponding source file but with the sample
716
+ index appended. E.g. If the the source file name is 'images.tiff' then the first
717
+ sample's prediction will be saved with the file name "image_0.tiff".
718
+ Input can be a PredictDataModule instance, a path to a data file, or a numpy
719
+ array.
720
+
721
+ If `data_type`, `axes` and `tile_size` are not provided, the training
722
+ configuration parameters will be used, with the `patch_size` instead of
723
+ `tile_size`.
724
+
725
+ Test-time augmentation (TTA) can be switched on using the `tta_transforms`
726
+ parameter. The TTA augmentation applies all possible flip and 90 degrees
727
+ rotations to the prediction input and averages the predictions. TTA augmentation
728
+ should not be used if you did not train with these augmentations.
729
+
730
+ Note that if you are using a UNet model and tiling, the tile size must be
731
+ divisible in every dimension by 2**d, where d is the depth of the model. This
732
+ avoids artefacts arising from the broken shift invariance induced by the
733
+ pooling layers of the UNet. If your image has less dimensions, as it may
734
+ happen in the Z dimension, consider padding your image.
735
+
736
+ Parameters
737
+ ----------
738
+ source : PredictDataModule or pathlib.Path, str
739
+ Data to predict on.
740
+ batch_size : int, default=1
741
+ Batch size for prediction.
742
+ tile_size : tuple of int, optional
743
+ Size of the tiles to use for prediction.
744
+ tile_overlap : tuple of int, default=(48, 48)
745
+ Overlap between tiles.
746
+ axes : str, optional
747
+ Axes of the input data, by default None.
748
+ data_type : {"array", "tiff", "custom"}, optional
749
+ Type of the input data.
750
+ tta_transforms : bool, default=True
751
+ Whether to apply test-time augmentation.
752
+ dataloader_params : dict, optional
753
+ Parameters to pass to the dataloader.
754
+ read_source_func : Callable, optional
755
+ Function to read the source data.
756
+ extension_filter : str, default=""
757
+ Filter for the file extension.
758
+ write_type : {"tiff", "custom"}, default="tiff"
759
+ The data type to save as, includes custom.
760
+ write_extension : str, optional
761
+ If a known `write_type` is selected this argument is ignored. For a custom
762
+ `write_type` an extension to save the data with must be passed.
763
+ write_func : WriteFunc, optional
764
+ If a known `write_type` is selected this argument is ignored. For a custom
765
+ `write_type` a function to save the data must be passed. See notes below.
766
+ write_func_kwargs : dict of {str: any}, optional
767
+ Additional keyword arguments to be passed to the save function.
768
+ prediction_dir : Path | str, default="predictions"
769
+ The path to save the prediction results to. If `prediction_dir` is not
770
+ absolute, the directory will be assumed to be relative to the pre-set
771
+ `work_dir`. If the directory does not exist it will be created.
772
+ **kwargs : Any
773
+ Unused.
774
+
775
+ Raises
776
+ ------
777
+ ValueError
778
+ If `write_type` is custom and `write_extension` is None.
779
+ ValueError
780
+ If `write_type` is custom and `write_fun is None.
781
+ ValueError
782
+ If `source` is not `str`, `Path` or `PredictDataModule`
783
+ """
784
+ if write_func_kwargs is None:
785
+ write_func_kwargs = {}
786
+
787
+ if Path(prediction_dir).is_absolute():
788
+ write_dir = Path(prediction_dir)
789
+ else:
790
+ write_dir = self.work_dir / prediction_dir
791
+ write_dir.mkdir(exist_ok=True, parents=True)
792
+
793
+ # guards for custom types
794
+ if write_type == SupportedData.CUSTOM:
795
+ if write_extension is None:
796
+ raise ValueError(
797
+ "A `write_extension` must be provided for custom write types."
798
+ )
799
+ if write_func is None:
800
+ raise ValueError(
801
+ "A `write_func` must be provided for custom write types."
802
+ )
803
+ else:
804
+ write_func = get_write_func(write_type)
805
+ write_extension = SupportedData.get_extension(write_type)
806
+
807
+ # extract file names
808
+ source_path: Union[Path, str, NDArray]
809
+ source_data_type: Literal["array", "tiff", "custom"]
810
+ if isinstance(source, PredictDataModule):
811
+ source_path = source.pred_data
812
+ source_data_type = source.data_type
813
+ extension_filter = source.extension_filter
814
+ elif isinstance(source, (str, Path)):
815
+ source_path = source
816
+ source_data_type = data_type or self.cfg.data_config.data_type
817
+ extension_filter = SupportedData.get_extension_pattern(
818
+ SupportedData(source_data_type)
819
+ )
820
+ else:
821
+ raise ValueError(f"Unsupported source type: '{type(source)}'.")
822
+
823
+ if source_data_type == "array":
824
+ raise ValueError(
825
+ "Predicting to disk is not supported for input type 'array'."
826
+ )
827
+ assert isinstance(source_path, (Path, str)) # because data_type != "array"
828
+ source_path = Path(source_path)
829
+
830
+ file_paths = list_files(source_path, source_data_type, extension_filter)
831
+
832
+ # predict and write each file in turn
833
+ for file_path in file_paths:
834
+ # source_path is relative to original source path...
835
+ # should mirror original directory structure
836
+ prediction = self.predict(
837
+ source=file_path,
838
+ batch_size=batch_size,
839
+ tile_size=tile_size,
840
+ tile_overlap=tile_overlap,
841
+ axes=axes,
842
+ data_type=data_type,
843
+ tta_transforms=tta_transforms,
844
+ dataloader_params=dataloader_params,
845
+ read_source_func=read_source_func,
846
+ extension_filter=extension_filter,
847
+ **kwargs,
848
+ )
849
+ # TODO: cast to float16?
850
+ write_data = np.concatenate(prediction)
851
+
852
+ # create directory structure and write path
853
+ if not source_path.is_file():
854
+ file_write_dir = write_dir / file_path.parent.relative_to(source_path)
855
+ else:
856
+ file_write_dir = write_dir
857
+ file_write_dir.mkdir(parents=True, exist_ok=True)
858
+ write_path = (file_write_dir / file_path.name).with_suffix(write_extension)
859
+
860
+ # write data
861
+ write_func(file_path=write_path, img=write_data)
862
+
671
863
  def export_to_bmz(
672
864
  self,
673
865
  path_to_archive: Union[Path, str],
674
866
  friendly_model_name: str,
675
867
  input_array: NDArray,
676
868
  authors: list[dict],
677
- general_description: str = "",
869
+ general_description: str,
870
+ data_description: str,
871
+ covers: Optional[list[Union[Path, str]]] = None,
678
872
  channel_names: Optional[list[str]] = None,
679
- data_description: Optional[str] = None,
873
+ model_version: str = "0.1.0",
680
874
  ) -> None:
681
875
  """Export the model to the BioImage Model Zoo format.
682
876
 
@@ -706,11 +900,15 @@ class CAREamist:
706
900
  authors : list of dict
707
901
  List of authors of the model.
708
902
  general_description : str
709
- General description of the model, used in the metadata of the BMZ archive.
710
- channel_names : list of str, optional
711
- Channel names, by default None.
712
- data_description : str, optional
713
- Description of the data, by default None.
903
+ General description of the model used in the BMZ metadata.
904
+ data_description : str
905
+ Description of the data the model was trained on.
906
+ covers : list of pathlib.Path or str, default=None
907
+ Paths to the cover images.
908
+ channel_names : list of str, default=None
909
+ Channel names.
910
+ model_version : str, default="0.1.0"
911
+ Version of the model.
714
912
  """
715
913
  # TODO: add in docs that it is expected that input_array dimensions match
716
914
  # those in data_config
@@ -729,9 +927,21 @@ class CAREamist:
729
927
  path_to_archive=path_to_archive,
730
928
  model_name=friendly_model_name,
731
929
  general_description=general_description,
930
+ data_description=data_description,
732
931
  authors=authors,
733
932
  input_array=input_array,
734
933
  output_array=output,
934
+ covers=covers,
735
935
  channel_names=channel_names,
736
- data_description=data_description,
936
+ model_version=model_version,
737
937
  )
938
+
939
+ def get_losses(self) -> dict[str, list]:
940
+ """Return data that can be used to plot train and validation loss curves.
941
+
942
+ Returns
943
+ -------
944
+ dict of str: list
945
+ Dictionary containing the losses for each epoch.
946
+ """
947
+ return read_csv_logger(self.cfg.experiment_name, self.work_dir / "csv_logs")
careamics/cli/conf.py CHANGED
@@ -3,7 +3,7 @@
3
3
  import sys
4
4
  from dataclasses import dataclass
5
5
  from pathlib import Path
6
- from typing import Tuple
6
+ from typing import Optional
7
7
 
8
8
  import click
9
9
  import typer
@@ -17,6 +17,7 @@ from ..config import (
17
17
  create_n2v_configuration,
18
18
  save_configuration,
19
19
  )
20
+ from .utils import handle_2D_3D_callback
20
21
 
21
22
  WORK_DIR = Path.cwd()
22
23
 
@@ -92,26 +93,6 @@ def conf_options( # numpydoc ignore=PR01
92
93
  ctx.obj = ConfOptions(dir, name, force, print)
93
94
 
94
95
 
95
- def patch_size_callback(value: Tuple[int, int, int]) -> Tuple[int, ...]:
96
- """
97
- Callback for --patch-size option.
98
-
99
- Parameters
100
- ----------
101
- value : (int, int, int)
102
- Patch size value.
103
-
104
- Returns
105
- -------
106
- (int, int, int) | (int, int)
107
- If the last element in `value` is -1 the tuple is reduced to the first two
108
- values.
109
- """
110
- if value[2] == -1:
111
- return value[:2]
112
- return value
113
-
114
-
115
96
  # TODO: Need to decide how to parse model kwargs
116
97
  # - Could be json style string to be loaded as dict e.g. {"depth": 3}
117
98
  # - Cons: Annoying to type, easily have syntax errors
@@ -132,7 +113,7 @@ def care( # numpydoc ignore=PR01
132
113
  "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
133
114
  ),
134
115
  click_type=click.Tuple([int, int, int]),
135
- callback=patch_size_callback,
116
+ callback=handle_2D_3D_callback,
136
117
  ),
137
118
  ],
138
119
  batch_size: Annotated[int, typer.Option(help="Batch size.")],
@@ -154,8 +135,12 @@ def care( # numpydoc ignore=PR01
154
135
  help="Loss function to use.",
155
136
  ),
156
137
  ] = "mae",
157
- n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
158
- n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
138
+ n_channels_in: Annotated[
139
+ Optional[int], typer.Option(help="Number of channels in")
140
+ ] = None,
141
+ n_channels_out: Annotated[
142
+ Optional[int], typer.Option(help="Number of channels out")
143
+ ] = None,
159
144
  logger: Annotated[
160
145
  click.Choice,
161
146
  typer.Option(
@@ -215,7 +200,7 @@ def n2n( # numpydoc ignore=PR01
215
200
  "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
216
201
  ),
217
202
  click_type=click.Tuple([int, int, int]),
218
- callback=patch_size_callback,
203
+ callback=handle_2D_3D_callback,
219
204
  ),
220
205
  ],
221
206
  batch_size: Annotated[int, typer.Option(help="Batch size.")],
@@ -237,8 +222,12 @@ def n2n( # numpydoc ignore=PR01
237
222
  help="Loss function to use.",
238
223
  ),
239
224
  ] = "mae",
240
- n_channels_in: Annotated[int, typer.Option(help="Number of channels in")] = 1,
241
- n_channels_out: Annotated[int, typer.Option(help="Number of channels out")] = -1,
225
+ n_channels_in: Annotated[
226
+ Optional[int], typer.Option(help="Number of channels in")
227
+ ] = None,
228
+ n_channels_out: Annotated[
229
+ Optional[int], typer.Option(help="Number of channels out")
230
+ ] = None,
242
231
  logger: Annotated[
243
232
  click.Choice,
244
233
  typer.Option(
@@ -295,7 +284,7 @@ def n2v( # numpydoc ignore=PR01
295
284
  "is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
296
285
  ),
297
286
  click_type=click.Tuple([int, int, int]),
298
- callback=patch_size_callback,
287
+ callback=handle_2D_3D_callback,
299
288
  ),
300
289
  ],
301
290
  batch_size: Annotated[int, typer.Option(help="Batch size.")],
@@ -312,8 +301,8 @@ def n2v( # numpydoc ignore=PR01
312
301
  ] = True,
313
302
  use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
314
303
  n_channels: Annotated[
315
- int, typer.Option(help="Number of channels (in and out)")
316
- ] = 1,
304
+ Optional[int], typer.Option(help="Number of channels (in and out)")
305
+ ] = None,
317
306
  roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
318
307
  masked_pixel_percentage: Annotated[
319
308
  float, typer.Option(help="Percentage of pixels masked in each patch.")
careamics/cli/main.py CHANGED
@@ -9,21 +9,20 @@ its implementation is contained in the conf.py file.
9
9
  from pathlib import Path
10
10
  from typing import Optional
11
11
 
12
+ import click
12
13
  import typer
13
14
  from typing_extensions import Annotated
14
15
 
15
16
  from ..careamist import CAREamist
16
17
  from . import conf
18
+ from .utils import handle_2D_3D_callback
17
19
 
18
20
  app = typer.Typer(
19
21
  help="Run CAREamics algorithms from the command line, including Noise2Void "
20
- "and its many variants and cousins"
21
- )
22
- app.add_typer(
23
- conf.app,
24
- name="conf",
25
- # callback=conf.conf_options
22
+ "and its many variants and cousins",
23
+ pretty_exceptions_show_locals=False,
26
24
  )
25
+ app.add_typer(conf.app, name="conf")
27
26
 
28
27
 
29
28
  @app.command()
@@ -102,7 +101,7 @@ def train( # numpydoc ignore=PR01
102
101
  typer.Option(
103
102
  "--work-dir",
104
103
  "-wd",
105
- help=("Path to working directory in which to save checkpoints and " "logs"),
104
+ help=("Path to working directory in which to save checkpoints and logs"),
106
105
  exists=True,
107
106
  file_okay=False,
108
107
  dir_okay=True,
@@ -123,10 +122,112 @@ def train( # numpydoc ignore=PR01
123
122
 
124
123
 
125
124
  @app.command()
126
- def predict(): # numpydoc ignore=PR01
125
+ def predict( # numpydoc ignore=PR01
126
+ model: Annotated[
127
+ Path,
128
+ typer.Argument(
129
+ help="Path to a configuration file or a trained model.",
130
+ exists=True,
131
+ file_okay=True,
132
+ dir_okay=False,
133
+ ),
134
+ ],
135
+ source: Annotated[
136
+ Path,
137
+ typer.Argument(
138
+ help="Path to the training data. Can be a directory or single file.",
139
+ exists=True,
140
+ file_okay=True,
141
+ dir_okay=True,
142
+ ),
143
+ ],
144
+ batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
145
+ tile_size: Annotated[
146
+ Optional[click.Tuple],
147
+ typer.Option(
148
+ help=(
149
+ "Size of the tiles to use for prediction, (if the data "
150
+ "is not 3D pass the last value as -1 e.g. --tile_size 64 64 -1)."
151
+ ),
152
+ click_type=click.Tuple([int, int, int]),
153
+ callback=handle_2D_3D_callback,
154
+ ),
155
+ ] = None,
156
+ tile_overlap: Annotated[
157
+ click.Tuple,
158
+ typer.Option(
159
+ help=(
160
+ "Overlap between tiles, (if the data is not 3D pass the last value as "
161
+ "-1 e.g. --tile_overlap 64 64 -1)."
162
+ ),
163
+ click_type=click.Tuple([int, int, int]),
164
+ callback=handle_2D_3D_callback,
165
+ ),
166
+ ] = (48, 48, -1),
167
+ axes: Annotated[
168
+ Optional[str],
169
+ typer.Option(
170
+ help="Axes of the input data. If unused the data is assumed to have the "
171
+ "same axes as the original training data."
172
+ ),
173
+ ] = None,
174
+ data_type: Annotated[
175
+ click.Choice,
176
+ typer.Option(click_type=click.Choice(["tiff"]), help="Type of the input data."),
177
+ ] = "tiff",
178
+ tta_transforms: Annotated[
179
+ bool,
180
+ typer.Option(
181
+ "--tta-transforms/--no-tta-transforms",
182
+ "-t/-T",
183
+ help="Whether to apply test-time augmentation.",
184
+ ),
185
+ ] = False,
186
+ write_type: Annotated[
187
+ click.Choice,
188
+ typer.Option(
189
+ click_type=click.Choice(["tiff"]), help="Type of the output data."
190
+ ),
191
+ ] = "tiff",
192
+ # TODO: could make dataloader_params as json, necessary?
193
+ work_dir: Annotated[
194
+ Optional[Path],
195
+ typer.Option(
196
+ "--work-dir",
197
+ "-wd",
198
+ help=("Path to working directory."),
199
+ exists=True,
200
+ file_okay=False,
201
+ dir_okay=True,
202
+ ),
203
+ ] = None,
204
+ prediction_dir: Annotated[
205
+ Path,
206
+ typer.Option(
207
+ "--prediction-dir",
208
+ "-pd",
209
+ help=(
210
+ "Directory to save predictions to. If not an abosulte path it will be "
211
+ "relative to the set working directory."
212
+ ),
213
+ file_okay=False,
214
+ dir_okay=True,
215
+ ),
216
+ ] = Path("predictions"),
217
+ ):
127
218
  """Create and save predictions from CAREamics models."""
128
- # TODO: Need a save predict to workdir function
129
- raise NotImplementedError
219
+ engine = CAREamist(source=model, work_dir=work_dir)
220
+ engine.predict_to_disk(
221
+ source=source,
222
+ batch_size=batch_size,
223
+ tile_size=tile_size,
224
+ tile_overlap=tile_overlap,
225
+ axes=axes,
226
+ data_type=data_type,
227
+ tta_transforms=tta_transforms,
228
+ write_type=write_type,
229
+ prediction_dir=prediction_dir,
230
+ )
130
231
 
131
232
 
132
233
  def run():
careamics/cli/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ """Utility functions for the CAREamics CLI."""
2
+
3
+ from typing import Optional, Tuple
4
+
5
+
6
+ def handle_2D_3D_callback(
7
+ value: Optional[Tuple[int, int, int]]
8
+ ) -> Optional[Tuple[int, ...]]:
9
+ """
10
+ Callback for options that require 2D or 3D inputs.
11
+
12
+ In the case of 2D, the 3rd element should be set to -1.
13
+
14
+ Parameters
15
+ ----------
16
+ value : (int, int, int)
17
+ Tile size value.
18
+
19
+ Returns
20
+ -------
21
+ (int, int, int) | (int, int)
22
+ If the last element in `value` is -1 the tuple is reduced to the first two
23
+ values.
24
+ """
25
+ if value is None:
26
+ return value
27
+ if value[2] == -1:
28
+ return value[:2]
29
+ return value
@@ -18,6 +18,7 @@ __all__ = [
18
18
  "clear_custom_models",
19
19
  "GaussianMixtureNMConfig",
20
20
  "MultiChannelNMConfig",
21
+ "LVAELossConfig",
21
22
  ]
22
23
  from .architectures import CustomModel, clear_custom_models, register_model
23
24
  from .callback_model import CheckpointModel
@@ -34,6 +35,7 @@ from .configuration_model import (
34
35
  from .data_model import DataConfig
35
36
  from .fcn_algorithm_model import FCNAlgorithmConfig
36
37
  from .inference_model import InferenceConfig
38
+ from .loss_model import LVAELossConfig
37
39
  from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
38
40
  from .training_model import TrainingConfig
39
41
  from .vae_algorithm_model import VAEAlgorithmConfig