careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +235 -25
- careamics/cli/conf.py +19 -30
- careamics/cli/main.py +111 -10
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +2 -0
- careamics/config/architectures/lvae_model.py +104 -21
- careamics/config/configuration_factory.py +49 -45
- careamics/config/configuration_model.py +2 -2
- careamics/config/likelihood_model.py +7 -6
- careamics/config/loss_model.py +56 -0
- careamics/config/nm_model.py +24 -24
- careamics/config/vae_algorithm_model.py +14 -13
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/lightning/lightning_module.py +58 -27
- careamics/lightning/train_data_module.py +15 -1
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/bioimage/_readme_factory.py +25 -33
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +39 -17
- careamics/model_io/bmz_io.py +36 -25
- careamics/models/layers.py +6 -4
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -272
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
|
@@ -11,7 +11,7 @@ from pytorch_lightning.callbacks import (
|
|
|
11
11
|
EarlyStopping,
|
|
12
12
|
ModelCheckpoint,
|
|
13
13
|
)
|
|
14
|
-
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
|
14
|
+
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
|
|
15
15
|
|
|
16
16
|
from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
|
|
17
17
|
from careamics.config.support import (
|
|
@@ -20,7 +20,8 @@ from careamics.config.support import (
|
|
|
20
20
|
SupportedData,
|
|
21
21
|
SupportedLogger,
|
|
22
22
|
)
|
|
23
|
-
from careamics.dataset.dataset_utils import reshape_array
|
|
23
|
+
from careamics.dataset.dataset_utils import list_files, reshape_array
|
|
24
|
+
from careamics.file_io import WriteFunc, get_write_func
|
|
24
25
|
from careamics.lightning import (
|
|
25
26
|
FCNModule,
|
|
26
27
|
HyperParametersCallback,
|
|
@@ -32,10 +33,11 @@ from careamics.lightning import (
|
|
|
32
33
|
from careamics.model_io import export_to_bmz, load_pretrained
|
|
33
34
|
from careamics.prediction_utils import convert_outputs
|
|
34
35
|
from careamics.utils import check_path_exists, get_logger
|
|
36
|
+
from careamics.utils.lightning_utils import read_csv_logger
|
|
35
37
|
|
|
36
38
|
logger = get_logger(__name__)
|
|
37
39
|
|
|
38
|
-
LOGGER_TYPES =
|
|
40
|
+
LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
class CAREamist:
|
|
@@ -144,6 +146,7 @@ class CAREamist:
|
|
|
144
146
|
|
|
145
147
|
# path to configuration file or model
|
|
146
148
|
else:
|
|
149
|
+
# TODO: update this check so models can be downloaded directly from BMZ
|
|
147
150
|
source = check_path_exists(source)
|
|
148
151
|
|
|
149
152
|
# configuration file
|
|
@@ -169,18 +172,29 @@ class CAREamist:
|
|
|
169
172
|
self._define_callbacks(callbacks)
|
|
170
173
|
|
|
171
174
|
# instantiate logger
|
|
175
|
+
csv_logger = CSVLogger(
|
|
176
|
+
name=self.cfg.experiment_name,
|
|
177
|
+
save_dir=self.work_dir / "csv_logs",
|
|
178
|
+
)
|
|
179
|
+
|
|
172
180
|
if self.cfg.training_config.has_logger():
|
|
173
181
|
if self.cfg.training_config.logger == SupportedLogger.WANDB:
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
182
|
+
experiment_logger: LOGGER_TYPES = [
|
|
183
|
+
WandbLogger(
|
|
184
|
+
name=self.cfg.experiment_name,
|
|
185
|
+
save_dir=self.work_dir / Path("wandb_logs"),
|
|
186
|
+
),
|
|
187
|
+
csv_logger,
|
|
188
|
+
]
|
|
178
189
|
elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
190
|
+
experiment_logger = [
|
|
191
|
+
TensorBoardLogger(
|
|
192
|
+
save_dir=self.work_dir / Path("tb_logs"),
|
|
193
|
+
),
|
|
194
|
+
csv_logger,
|
|
195
|
+
]
|
|
182
196
|
else:
|
|
183
|
-
|
|
197
|
+
experiment_logger = [csv_logger]
|
|
184
198
|
|
|
185
199
|
# instantiate trainer
|
|
186
200
|
self.trainer = Trainer(
|
|
@@ -194,7 +208,7 @@ class CAREamist:
|
|
|
194
208
|
gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
|
|
195
209
|
callbacks=self.callbacks,
|
|
196
210
|
default_root_dir=self.work_dir,
|
|
197
|
-
logger=
|
|
211
|
+
logger=experiment_logger,
|
|
198
212
|
)
|
|
199
213
|
|
|
200
214
|
# place holder for the datamodules
|
|
@@ -519,7 +533,7 @@ class CAREamist:
|
|
|
519
533
|
*,
|
|
520
534
|
batch_size: int = 1,
|
|
521
535
|
tile_size: Optional[tuple[int, ...]] = None,
|
|
522
|
-
tile_overlap: tuple[int, ...] = (48, 48),
|
|
536
|
+
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
|
|
523
537
|
axes: Optional[str] = None,
|
|
524
538
|
data_type: Optional[Literal["tiff", "custom"]] = None,
|
|
525
539
|
tta_transforms: bool = False,
|
|
@@ -535,7 +549,7 @@ class CAREamist:
|
|
|
535
549
|
*,
|
|
536
550
|
batch_size: int = 1,
|
|
537
551
|
tile_size: Optional[tuple[int, ...]] = None,
|
|
538
|
-
tile_overlap: tuple[int, ...] = (48, 48),
|
|
552
|
+
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
|
|
539
553
|
axes: Optional[str] = None,
|
|
540
554
|
data_type: Optional[Literal["array"]] = None,
|
|
541
555
|
tta_transforms: bool = False,
|
|
@@ -546,7 +560,7 @@ class CAREamist:
|
|
|
546
560
|
self,
|
|
547
561
|
source: Union[PredictDataModule, Path, str, NDArray],
|
|
548
562
|
*,
|
|
549
|
-
batch_size:
|
|
563
|
+
batch_size: int = 1,
|
|
550
564
|
tile_size: Optional[tuple[int, ...]] = None,
|
|
551
565
|
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
|
|
552
566
|
axes: Optional[str] = None,
|
|
@@ -567,7 +581,7 @@ class CAREamist:
|
|
|
567
581
|
configuration parameters will be used, with the `patch_size` instead of
|
|
568
582
|
`tile_size`.
|
|
569
583
|
|
|
570
|
-
Test-time augmentation (TTA) can be switched
|
|
584
|
+
Test-time augmentation (TTA) can be switched on using the `tta_transforms`
|
|
571
585
|
parameter. The TTA augmentation applies all possible flip and 90 degrees
|
|
572
586
|
rotations to the prediction input and averages the predictions. TTA augmentation
|
|
573
587
|
should not be used if you did not train with these augmentations.
|
|
@@ -580,7 +594,7 @@ class CAREamist:
|
|
|
580
594
|
|
|
581
595
|
Parameters
|
|
582
596
|
----------
|
|
583
|
-
source :
|
|
597
|
+
source : PredictDataModule, pathlib.Path, str or numpy.ndarray
|
|
584
598
|
Data to predict on.
|
|
585
599
|
batch_size : int, default=1
|
|
586
600
|
Batch size for prediction.
|
|
@@ -668,15 +682,195 @@ class CAREamist:
|
|
|
668
682
|
)
|
|
669
683
|
return convert_outputs(predictions, self.pred_datamodule.tiled)
|
|
670
684
|
|
|
685
|
+
def predict_to_disk(
    self,
    source: Union[PredictDataModule, Path, str],
    *,
    batch_size: int = 1,
    tile_size: Optional[tuple[int, ...]] = None,
    tile_overlap: Optional[tuple[int, ...]] = (48, 48),
    axes: Optional[str] = None,
    data_type: Optional[Literal["tiff", "custom"]] = None,
    tta_transforms: bool = False,
    dataloader_params: Optional[dict] = None,
    read_source_func: Optional[Callable] = None,
    extension_filter: str = "",
    write_type: Literal["tiff", "custom"] = "tiff",
    write_extension: Optional[str] = None,
    write_func: Optional[WriteFunc] = None,
    write_func_kwargs: Optional[dict[str, Any]] = None,
    prediction_dir: Union[Path, str] = "predictions",
    **kwargs,
) -> None:
    """
    Make predictions on the provided data and save outputs to files.

    The predictions are saved in `prediction_dir` (by default a directory named
    'predictions' within the set working directory). When `source` is a
    directory, the directory structure within `prediction_dir` matches that of
    the source directory.

    The `source` must be read from files and not from arrays. Each source file
    produces exactly one output file: the list of arrays returned by `predict`
    for that file is concatenated along its first axis and written under the
    source file name, with the suffix replaced by the write extension. E.g. if
    the source file name is 'images.tif' and `write_type` is "tiff", the
    prediction is saved as 'images.tiff'.

    If `data_type`, `axes` and `tile_size` are not provided, the training
    configuration parameters will be used, with the `patch_size` instead of
    `tile_size`.

    Test-time augmentation (TTA) can be switched on using the `tta_transforms`
    parameter. The TTA augmentation applies all possible flip and 90 degrees
    rotations to the prediction input and averages the predictions. TTA augmentation
    should not be used if you did not train with these augmentations.

    Note that if you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. If your image has less dimensions, as it may
    happen in the Z dimension, consider padding your image.

    Parameters
    ----------
    source : PredictDataModule, pathlib.Path or str
        Data to predict on. Must be file-based (not an array).
    batch_size : int, default=1
        Batch size for prediction.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction.
    tile_overlap : tuple of int, default=(48, 48)
        Overlap between tiles.
    axes : str, optional
        Axes of the input data, by default None.
    data_type : {"tiff", "custom"}, optional
        Type of the input data. If None, the data type of the training
        configuration is used. Ignored when `source` is a PredictDataModule,
        whose own data type is used instead.
    tta_transforms : bool, default=False
        Whether to apply test-time augmentation.
    dataloader_params : dict, optional
        Parameters to pass to the dataloader.
    read_source_func : Callable, optional
        Function to read the source data.
    extension_filter : str, default=""
        Filter for the file extension when `source` is a path. If empty, the
        default extension pattern of the data type is used. Ignored when
        `source` is a PredictDataModule, whose own filter is used instead.
    write_type : {"tiff", "custom"}, default="tiff"
        The data type to save as, includes custom.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments passed to the write function.
    prediction_dir : Path | str, default="predictions"
        The path to save the prediction results to. If `prediction_dir` is not
        absolute, the directory will be assumed to be relative to the pre-set
        `work_dir`. If the directory does not exist it will be created.
    **kwargs : Any
        Additional keyword arguments forwarded to `predict`.

    Raises
    ------
    ValueError
        If `write_type` is custom and `write_extension` is None.
    ValueError
        If `write_type` is custom and `write_func` is None.
    ValueError
        If `source` is not `str`, `Path` or `PredictDataModule`.
    ValueError
        If the data type of `source` is "array".
    """
    if write_func_kwargs is None:
        write_func_kwargs = {}

    # relative prediction directories are anchored at the working directory
    write_dir = Path(prediction_dir)
    if not write_dir.is_absolute():
        write_dir = self.work_dir / write_dir
    write_dir.mkdir(exist_ok=True, parents=True)

    # guards for custom types
    if write_type == SupportedData.CUSTOM:
        if write_extension is None:
            raise ValueError(
                "A `write_extension` must be provided for custom write types."
            )
        if write_func is None:
            raise ValueError(
                "A `write_func` must be provided for custom write types."
            )
    else:
        write_func = get_write_func(write_type)
        write_extension = SupportedData.get_extension(write_type)

    # extract the source location and its data type
    source_path: Union[Path, str, NDArray]
    source_data_type: Literal["array", "tiff", "custom"]
    if isinstance(source, PredictDataModule):
        source_path = source.pred_data
        source_data_type = source.data_type
        extension_filter = source.extension_filter
    elif isinstance(source, (str, Path)):
        source_path = source
        source_data_type = data_type or self.cfg.data_config.data_type
    else:
        raise ValueError(f"Unsupported source type: '{type(source)}'.")

    # arrays have no file names to mirror: reject them before any
    # file-specific lookup (such as the extension pattern) is attempted
    if source_data_type == "array":
        raise ValueError(
            "Predicting to disk is not supported for input type 'array'."
        )

    # respect an explicit user filter, otherwise use the data type's default
    if isinstance(source, (str, Path)) and not extension_filter:
        extension_filter = SupportedData.get_extension_pattern(
            SupportedData(source_data_type)
        )

    assert isinstance(source_path, (Path, str))  # because data_type != "array"
    source_path = Path(source_path)

    file_paths = list_files(source_path, source_data_type, extension_filter)

    # predict and write each file in turn
    for file_path in file_paths:
        prediction = self.predict(
            source=file_path,
            batch_size=batch_size,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
            axes=axes,
            # use the resolved type so a PredictDataModule's own data type
            # is honoured instead of falling back to the configuration
            data_type=source_data_type,
            tta_transforms=tta_transforms,
            dataloader_params=dataloader_params,
            read_source_func=read_source_func,
            extension_filter=extension_filter,
            **kwargs,
        )
        # TODO: cast to float16?
        write_data = np.concatenate(prediction)

        # mirror the source directory structure inside the write directory
        if source_path.is_file():
            file_write_dir = write_dir
        else:
            file_write_dir = write_dir / file_path.parent.relative_to(source_path)
        file_write_dir.mkdir(parents=True, exist_ok=True)
        write_path = (file_write_dir / file_path.name).with_suffix(write_extension)

        # write data, forwarding the documented extra keyword arguments
        write_func(file_path=write_path, img=write_data, **write_func_kwargs)
|
|
862
|
+
|
|
671
863
|
def export_to_bmz(
|
|
672
864
|
self,
|
|
673
865
|
path_to_archive: Union[Path, str],
|
|
674
866
|
friendly_model_name: str,
|
|
675
867
|
input_array: NDArray,
|
|
676
868
|
authors: list[dict],
|
|
677
|
-
general_description: str
|
|
869
|
+
general_description: str,
|
|
870
|
+
data_description: str,
|
|
871
|
+
covers: Optional[list[Union[Path, str]]] = None,
|
|
678
872
|
channel_names: Optional[list[str]] = None,
|
|
679
|
-
|
|
873
|
+
model_version: str = "0.1.0",
|
|
680
874
|
) -> None:
|
|
681
875
|
"""Export the model to the BioImage Model Zoo format.
|
|
682
876
|
|
|
@@ -706,11 +900,15 @@ class CAREamist:
|
|
|
706
900
|
authors : list of dict
|
|
707
901
|
List of authors of the model.
|
|
708
902
|
general_description : str
|
|
709
|
-
General description of the model
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
903
|
+
General description of the model used in the BMZ metadata.
|
|
904
|
+
data_description : str
|
|
905
|
+
Description of the data the model was trained on.
|
|
906
|
+
covers : list of pathlib.Path or str, default=None
|
|
907
|
+
Paths to the cover images.
|
|
908
|
+
channel_names : list of str, default=None
|
|
909
|
+
Channel names.
|
|
910
|
+
model_version : str, default="0.1.0"
|
|
911
|
+
Version of the model.
|
|
714
912
|
"""
|
|
715
913
|
# TODO: add in docs that it is expected that input_array dimensions match
|
|
716
914
|
# those in data_config
|
|
@@ -729,9 +927,21 @@ class CAREamist:
|
|
|
729
927
|
path_to_archive=path_to_archive,
|
|
730
928
|
model_name=friendly_model_name,
|
|
731
929
|
general_description=general_description,
|
|
930
|
+
data_description=data_description,
|
|
732
931
|
authors=authors,
|
|
733
932
|
input_array=input_array,
|
|
734
933
|
output_array=output,
|
|
934
|
+
covers=covers,
|
|
735
935
|
channel_names=channel_names,
|
|
736
|
-
|
|
936
|
+
model_version=model_version,
|
|
737
937
|
)
|
|
938
|
+
|
|
939
|
+
def get_losses(self) -> dict[str, list]:
    """Collect the per-epoch loss values recorded by the CSV logger.

    The values are read from the ``csv_logs`` directory inside the working
    directory, under the experiment name of the current configuration, and can
    be used to plot train and validation loss curves.

    Returns
    -------
    dict of str: list
        Dictionary containing the losses for each epoch.
    """
    experiment = self.cfg.experiment_name
    csv_log_dir = self.work_dir / "csv_logs"
    return read_csv_logger(experiment, csv_log_dir)
|
careamics/cli/conf.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import sys
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
6
|
+
from typing import Optional
|
|
7
7
|
|
|
8
8
|
import click
|
|
9
9
|
import typer
|
|
@@ -17,6 +17,7 @@ from ..config import (
|
|
|
17
17
|
create_n2v_configuration,
|
|
18
18
|
save_configuration,
|
|
19
19
|
)
|
|
20
|
+
from .utils import handle_2D_3D_callback
|
|
20
21
|
|
|
21
22
|
WORK_DIR = Path.cwd()
|
|
22
23
|
|
|
@@ -92,26 +93,6 @@ def conf_options( # numpydoc ignore=PR01
|
|
|
92
93
|
ctx.obj = ConfOptions(dir, name, force, print)
|
|
93
94
|
|
|
94
95
|
|
|
95
|
-
def patch_size_callback(value: Tuple[int, int, int]) -> Tuple[int, ...]:
|
|
96
|
-
"""
|
|
97
|
-
Callback for --patch-size option.
|
|
98
|
-
|
|
99
|
-
Parameters
|
|
100
|
-
----------
|
|
101
|
-
value : (int, int, int)
|
|
102
|
-
Patch size value.
|
|
103
|
-
|
|
104
|
-
Returns
|
|
105
|
-
-------
|
|
106
|
-
(int, int, int) | (int, int)
|
|
107
|
-
If the last element in `value` is -1 the tuple is reduced to the first two
|
|
108
|
-
values.
|
|
109
|
-
"""
|
|
110
|
-
if value[2] == -1:
|
|
111
|
-
return value[:2]
|
|
112
|
-
return value
|
|
113
|
-
|
|
114
|
-
|
|
115
96
|
# TODO: Need to decide how to parse model kwargs
|
|
116
97
|
# - Could be json style string to be loaded as dict e.g. {"depth": 3}
|
|
117
98
|
# - Cons: Annoying to type, easily have syntax errors
|
|
@@ -132,7 +113,7 @@ def care( # numpydoc ignore=PR01
|
|
|
132
113
|
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
|
|
133
114
|
),
|
|
134
115
|
click_type=click.Tuple([int, int, int]),
|
|
135
|
-
callback=
|
|
116
|
+
callback=handle_2D_3D_callback,
|
|
136
117
|
),
|
|
137
118
|
],
|
|
138
119
|
batch_size: Annotated[int, typer.Option(help="Batch size.")],
|
|
@@ -154,8 +135,12 @@ def care( # numpydoc ignore=PR01
|
|
|
154
135
|
help="Loss function to use.",
|
|
155
136
|
),
|
|
156
137
|
] = "mae",
|
|
157
|
-
n_channels_in: Annotated[
|
|
158
|
-
|
|
138
|
+
n_channels_in: Annotated[
|
|
139
|
+
Optional[int], typer.Option(help="Number of channels in")
|
|
140
|
+
] = None,
|
|
141
|
+
n_channels_out: Annotated[
|
|
142
|
+
Optional[int], typer.Option(help="Number of channels out")
|
|
143
|
+
] = None,
|
|
159
144
|
logger: Annotated[
|
|
160
145
|
click.Choice,
|
|
161
146
|
typer.Option(
|
|
@@ -215,7 +200,7 @@ def n2n( # numpydoc ignore=PR01
|
|
|
215
200
|
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
|
|
216
201
|
),
|
|
217
202
|
click_type=click.Tuple([int, int, int]),
|
|
218
|
-
callback=
|
|
203
|
+
callback=handle_2D_3D_callback,
|
|
219
204
|
),
|
|
220
205
|
],
|
|
221
206
|
batch_size: Annotated[int, typer.Option(help="Batch size.")],
|
|
@@ -237,8 +222,12 @@ def n2n( # numpydoc ignore=PR01
|
|
|
237
222
|
help="Loss function to use.",
|
|
238
223
|
),
|
|
239
224
|
] = "mae",
|
|
240
|
-
n_channels_in: Annotated[
|
|
241
|
-
|
|
225
|
+
n_channels_in: Annotated[
|
|
226
|
+
Optional[int], typer.Option(help="Number of channels in")
|
|
227
|
+
] = None,
|
|
228
|
+
n_channels_out: Annotated[
|
|
229
|
+
Optional[int], typer.Option(help="Number of channels out")
|
|
230
|
+
] = None,
|
|
242
231
|
logger: Annotated[
|
|
243
232
|
click.Choice,
|
|
244
233
|
typer.Option(
|
|
@@ -295,7 +284,7 @@ def n2v( # numpydoc ignore=PR01
|
|
|
295
284
|
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
|
|
296
285
|
),
|
|
297
286
|
click_type=click.Tuple([int, int, int]),
|
|
298
|
-
callback=
|
|
287
|
+
callback=handle_2D_3D_callback,
|
|
299
288
|
),
|
|
300
289
|
],
|
|
301
290
|
batch_size: Annotated[int, typer.Option(help="Batch size.")],
|
|
@@ -312,8 +301,8 @@ def n2v( # numpydoc ignore=PR01
|
|
|
312
301
|
] = True,
|
|
313
302
|
use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
|
|
314
303
|
n_channels: Annotated[
|
|
315
|
-
int, typer.Option(help="Number of channels (in and out)")
|
|
316
|
-
] =
|
|
304
|
+
Optional[int], typer.Option(help="Number of channels (in and out)")
|
|
305
|
+
] = None,
|
|
317
306
|
roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
|
|
318
307
|
masked_pixel_percentage: Annotated[
|
|
319
308
|
float, typer.Option(help="Percentage of pixels masked in each patch.")
|
careamics/cli/main.py
CHANGED
|
@@ -9,21 +9,20 @@ its implementation is contained in the conf.py file.
|
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
from typing import Optional
|
|
11
11
|
|
|
12
|
+
import click
|
|
12
13
|
import typer
|
|
13
14
|
from typing_extensions import Annotated
|
|
14
15
|
|
|
15
16
|
from ..careamist import CAREamist
|
|
16
17
|
from . import conf
|
|
18
|
+
from .utils import handle_2D_3D_callback
|
|
17
19
|
|
|
18
20
|
app = typer.Typer(
|
|
19
21
|
help="Run CAREamics algorithms from the command line, including Noise2Void "
|
|
20
|
-
"and its many variants and cousins"
|
|
21
|
-
|
|
22
|
-
app.add_typer(
|
|
23
|
-
conf.app,
|
|
24
|
-
name="conf",
|
|
25
|
-
# callback=conf.conf_options
|
|
22
|
+
"and its many variants and cousins",
|
|
23
|
+
pretty_exceptions_show_locals=False,
|
|
26
24
|
)
|
|
25
|
+
app.add_typer(conf.app, name="conf")
|
|
27
26
|
|
|
28
27
|
|
|
29
28
|
@app.command()
|
|
@@ -102,7 +101,7 @@ def train( # numpydoc ignore=PR01
|
|
|
102
101
|
typer.Option(
|
|
103
102
|
"--work-dir",
|
|
104
103
|
"-wd",
|
|
105
|
-
help=("Path to working directory in which to save checkpoints and
|
|
104
|
+
help=("Path to working directory in which to save checkpoints and logs"),
|
|
106
105
|
exists=True,
|
|
107
106
|
file_okay=False,
|
|
108
107
|
dir_okay=True,
|
|
@@ -123,10 +122,112 @@ def train( # numpydoc ignore=PR01
|
|
|
123
122
|
|
|
124
123
|
|
|
125
124
|
@app.command()
def predict(  # numpydoc ignore=PR01
    model: Annotated[
        Path,
        typer.Argument(
            help="Path to a configuration file or a trained model.",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ],
    source: Annotated[
        Path,
        typer.Argument(
            help="Path to the data to predict on. Can be a directory or single file.",
            exists=True,
            file_okay=True,
            dir_okay=True,
        ),
    ],
    batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
    tile_size: Annotated[
        Optional[click.Tuple],
        typer.Option(
            help=(
                "Size of the tiles to use for prediction, (if the data "
                "is not 3D pass the last value as -1 e.g. --tile-size 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=handle_2D_3D_callback,
        ),
    ] = None,
    tile_overlap: Annotated[
        click.Tuple,
        typer.Option(
            help=(
                "Overlap between tiles, (if the data is not 3D pass the last value as "
                "-1 e.g. --tile-overlap 64 64 -1)."
            ),
            click_type=click.Tuple([int, int, int]),
            callback=handle_2D_3D_callback,
        ),
    ] = (48, 48, -1),
    axes: Annotated[
        Optional[str],
        typer.Option(
            help="Axes of the input data. If unused the data is assumed to have the "
            "same axes as the original training data."
        ),
    ] = None,
    data_type: Annotated[
        click.Choice,
        typer.Option(click_type=click.Choice(["tiff"]), help="Type of the input data."),
    ] = "tiff",
    tta_transforms: Annotated[
        bool,
        typer.Option(
            "--tta-transforms/--no-tta-transforms",
            "-t/-T",
            help="Whether to apply test-time augmentation.",
        ),
    ] = False,
    write_type: Annotated[
        click.Choice,
        typer.Option(
            click_type=click.Choice(["tiff"]), help="Type of the output data."
        ),
    ] = "tiff",
    # TODO: could make dataloader_params as json, necessary?
    work_dir: Annotated[
        Optional[Path],
        typer.Option(
            "--work-dir",
            "-wd",
            help=("Path to working directory."),
            exists=True,
            file_okay=False,
            dir_okay=True,
        ),
    ] = None,
    prediction_dir: Annotated[
        Path,
        typer.Option(
            "--prediction-dir",
            "-pd",
            help=(
                "Directory to save predictions to. If not an absolute path it will be "
                "relative to the set working directory."
            ),
            file_okay=False,
            dir_okay=True,
        ),
    ] = Path("predictions"),
):
    """Create and save predictions from CAREamics models.

    A `CAREamist` is built from `model` (configuration file or trained model) and
    its `predict_to_disk` method writes one prediction file per source file into
    `prediction_dir`.
    """
    engine = CAREamist(source=model, work_dir=work_dir)
    engine.predict_to_disk(
        source=source,
        batch_size=batch_size,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        axes=axes,
        data_type=data_type,
        tta_transforms=tta_transforms,
        write_type=write_type,
        prediction_dir=prediction_dir,
    )
|
|
130
231
|
|
|
131
232
|
|
|
132
233
|
def run():
|
careamics/cli/utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Utility functions for the CAREamics CLI."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def handle_2D_3D_callback(
    value: Optional[Tuple[int, int, int]]
) -> Optional[Tuple[int, ...]]:
    """
    Callback for CLI options that accept either 2D or 3D values.

    Such options always receive three integers on the command line. For 2D
    data the third element must be set to -1 (e.g. ``--patch-size 64 64 -1``),
    in which case it is dropped. This is used for options such as the patch
    size, tile size and tile overlap.

    Parameters
    ----------
    value : (int, int, int) or None
        Three-element value passed on the command line, or None if the option
        was not provided.

    Returns
    -------
    (int, int, int) or (int, int) or None
        None if `value` is None; the first two values if the last element of
        `value` is -1; otherwise `value` unchanged.
    """
    if value is None:
        return None
    # -1 in the last position is the sentinel for "this is 2D data"
    if value[2] == -1:
        return value[:2]
    return value
|
careamics/config/__init__.py
CHANGED
|
@@ -18,6 +18,7 @@ __all__ = [
|
|
|
18
18
|
"clear_custom_models",
|
|
19
19
|
"GaussianMixtureNMConfig",
|
|
20
20
|
"MultiChannelNMConfig",
|
|
21
|
+
"LVAELossConfig",
|
|
21
22
|
]
|
|
22
23
|
from .architectures import CustomModel, clear_custom_models, register_model
|
|
23
24
|
from .callback_model import CheckpointModel
|
|
@@ -34,6 +35,7 @@ from .configuration_model import (
|
|
|
34
35
|
from .data_model import DataConfig
|
|
35
36
|
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
36
37
|
from .inference_model import InferenceConfig
|
|
38
|
+
from .loss_model import LVAELossConfig
|
|
37
39
|
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
38
40
|
from .training_model import TrainingConfig
|
|
39
41
|
from .vae_algorithm_model import VAEAlgorithmConfig
|