careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/__init__.py
CHANGED
|
@@ -7,7 +7,22 @@ try:
|
|
|
7
7
|
except PackageNotFoundError:
|
|
8
8
|
__version__ = "uninstalled"
|
|
9
9
|
|
|
10
|
-
__all__ = [
|
|
10
|
+
__all__ = [
|
|
11
|
+
"CAREamist",
|
|
12
|
+
"Configuration",
|
|
13
|
+
"algorithm_factory",
|
|
14
|
+
"configuration_factory",
|
|
15
|
+
"data_factory",
|
|
16
|
+
"load_configuration",
|
|
17
|
+
"save_configuration",
|
|
18
|
+
]
|
|
11
19
|
|
|
12
20
|
from .careamist import CAREamist
|
|
13
|
-
from .config import
|
|
21
|
+
from .config import (
|
|
22
|
+
Configuration,
|
|
23
|
+
algorithm_factory,
|
|
24
|
+
configuration_factory,
|
|
25
|
+
data_factory,
|
|
26
|
+
load_configuration,
|
|
27
|
+
save_configuration,
|
|
28
|
+
)
|
careamics/careamist.py
CHANGED
|
@@ -11,16 +11,17 @@ from pytorch_lightning.callbacks import (
|
|
|
11
11
|
EarlyStopping,
|
|
12
12
|
ModelCheckpoint,
|
|
13
13
|
)
|
|
14
|
-
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
|
14
|
+
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
|
|
15
15
|
|
|
16
|
-
from careamics.config import Configuration,
|
|
16
|
+
from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
|
|
17
17
|
from careamics.config.support import (
|
|
18
18
|
SupportedAlgorithm,
|
|
19
19
|
SupportedArchitecture,
|
|
20
20
|
SupportedData,
|
|
21
21
|
SupportedLogger,
|
|
22
22
|
)
|
|
23
|
-
from careamics.dataset.dataset_utils import reshape_array
|
|
23
|
+
from careamics.dataset.dataset_utils import list_files, reshape_array
|
|
24
|
+
from careamics.file_io import WriteFunc, get_write_func
|
|
24
25
|
from careamics.lightning import (
|
|
25
26
|
FCNModule,
|
|
26
27
|
HyperParametersCallback,
|
|
@@ -32,10 +33,11 @@ from careamics.lightning import (
|
|
|
32
33
|
from careamics.model_io import export_to_bmz, load_pretrained
|
|
33
34
|
from careamics.prediction_utils import convert_outputs
|
|
34
35
|
from careamics.utils import check_path_exists, get_logger
|
|
36
|
+
from careamics.utils.lightning_utils import read_csv_logger
|
|
35
37
|
|
|
36
38
|
logger = get_logger(__name__)
|
|
37
39
|
|
|
38
|
-
LOGGER_TYPES =
|
|
40
|
+
LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
|
|
39
41
|
|
|
40
42
|
|
|
41
43
|
class CAREamist:
|
|
@@ -135,7 +137,7 @@ class CAREamist:
|
|
|
135
137
|
self.cfg = source
|
|
136
138
|
|
|
137
139
|
# instantiate model
|
|
138
|
-
if isinstance(self.cfg.algorithm_config,
|
|
140
|
+
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
|
|
139
141
|
self.model = FCNModule(
|
|
140
142
|
algorithm_config=self.cfg.algorithm_config,
|
|
141
143
|
)
|
|
@@ -144,6 +146,7 @@ class CAREamist:
|
|
|
144
146
|
|
|
145
147
|
# path to configuration file or model
|
|
146
148
|
else:
|
|
149
|
+
# TODO: update this check so models can be downloaded directly from BMZ
|
|
147
150
|
source = check_path_exists(source)
|
|
148
151
|
|
|
149
152
|
# configuration file
|
|
@@ -154,7 +157,8 @@ class CAREamist:
|
|
|
154
157
|
self.cfg = load_configuration(source)
|
|
155
158
|
|
|
156
159
|
# instantiate model
|
|
157
|
-
|
|
160
|
+
# TODO call model factory here
|
|
161
|
+
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
|
|
158
162
|
self.model = FCNModule(
|
|
159
163
|
algorithm_config=self.cfg.algorithm_config,
|
|
160
164
|
) # type: ignore
|
|
@@ -169,18 +173,29 @@ class CAREamist:
|
|
|
169
173
|
self._define_callbacks(callbacks)
|
|
170
174
|
|
|
171
175
|
# instantiate logger
|
|
176
|
+
csv_logger = CSVLogger(
|
|
177
|
+
name=self.cfg.experiment_name,
|
|
178
|
+
save_dir=self.work_dir / "csv_logs",
|
|
179
|
+
)
|
|
180
|
+
|
|
172
181
|
if self.cfg.training_config.has_logger():
|
|
173
182
|
if self.cfg.training_config.logger == SupportedLogger.WANDB:
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
183
|
+
experiment_logger: LOGGER_TYPES = [
|
|
184
|
+
WandbLogger(
|
|
185
|
+
name=self.cfg.experiment_name,
|
|
186
|
+
save_dir=self.work_dir / Path("wandb_logs"),
|
|
187
|
+
),
|
|
188
|
+
csv_logger,
|
|
189
|
+
]
|
|
178
190
|
elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
191
|
+
experiment_logger = [
|
|
192
|
+
TensorBoardLogger(
|
|
193
|
+
save_dir=self.work_dir / Path("tb_logs"),
|
|
194
|
+
),
|
|
195
|
+
csv_logger,
|
|
196
|
+
]
|
|
182
197
|
else:
|
|
183
|
-
|
|
198
|
+
experiment_logger = [csv_logger]
|
|
184
199
|
|
|
185
200
|
# instantiate trainer
|
|
186
201
|
self.trainer = Trainer(
|
|
@@ -194,7 +209,7 @@ class CAREamist:
|
|
|
194
209
|
gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
|
|
195
210
|
callbacks=self.callbacks,
|
|
196
211
|
default_root_dir=self.work_dir,
|
|
197
|
-
logger=
|
|
212
|
+
logger=experiment_logger,
|
|
198
213
|
)
|
|
199
214
|
|
|
200
215
|
# place holder for the datamodules
|
|
@@ -519,7 +534,7 @@ class CAREamist:
|
|
|
519
534
|
*,
|
|
520
535
|
batch_size: int = 1,
|
|
521
536
|
tile_size: Optional[tuple[int, ...]] = None,
|
|
522
|
-
tile_overlap: tuple[int, ...] = (48, 48),
|
|
537
|
+
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
|
|
523
538
|
axes: Optional[str] = None,
|
|
524
539
|
data_type: Optional[Literal["tiff", "custom"]] = None,
|
|
525
540
|
tta_transforms: bool = False,
|
|
@@ -535,7 +550,7 @@ class CAREamist:
|
|
|
535
550
|
*,
|
|
536
551
|
batch_size: int = 1,
|
|
537
552
|
tile_size: Optional[tuple[int, ...]] = None,
|
|
538
|
-
tile_overlap: tuple[int, ...] = (48, 48),
|
|
553
|
+
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
|
|
539
554
|
axes: Optional[str] = None,
|
|
540
555
|
data_type: Optional[Literal["array"]] = None,
|
|
541
556
|
tta_transforms: bool = False,
|
|
@@ -546,7 +561,7 @@ class CAREamist:
|
|
|
546
561
|
self,
|
|
547
562
|
source: Union[PredictDataModule, Path, str, NDArray],
|
|
548
563
|
*,
|
|
549
|
-
batch_size:
|
|
564
|
+
batch_size: int = 1,
|
|
550
565
|
tile_size: Optional[tuple[int, ...]] = None,
|
|
551
566
|
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
|
|
552
567
|
axes: Optional[str] = None,
|
|
@@ -567,7 +582,7 @@ class CAREamist:
|
|
|
567
582
|
configuration parameters will be used, with the `patch_size` instead of
|
|
568
583
|
`tile_size`.
|
|
569
584
|
|
|
570
|
-
Test-time augmentation (TTA) can be switched
|
|
585
|
+
Test-time augmentation (TTA) can be switched on using the `tta_transforms`
|
|
571
586
|
parameter. The TTA augmentation applies all possible flip and 90 degrees
|
|
572
587
|
rotations to the prediction input and averages the predictions. TTA augmentation
|
|
573
588
|
should not be used if you did not train with these augmentations.
|
|
@@ -580,7 +595,7 @@ class CAREamist:
|
|
|
580
595
|
|
|
581
596
|
Parameters
|
|
582
597
|
----------
|
|
583
|
-
source :
|
|
598
|
+
source : PredictDataModule, pathlib.Path, str or numpy.ndarray
|
|
584
599
|
Data to predict on.
|
|
585
600
|
batch_size : int, default=1
|
|
586
601
|
Batch size for prediction.
|
|
@@ -668,15 +683,195 @@ class CAREamist:
|
|
|
668
683
|
)
|
|
669
684
|
return convert_outputs(predictions, self.pred_datamodule.tiled)
|
|
670
685
|
|
|
686
|
+
def predict_to_disk(
|
|
687
|
+
self,
|
|
688
|
+
source: Union[PredictDataModule, Path, str],
|
|
689
|
+
*,
|
|
690
|
+
batch_size: int = 1,
|
|
691
|
+
tile_size: Optional[tuple[int, ...]] = None,
|
|
692
|
+
tile_overlap: Optional[tuple[int, ...]] = (48, 48),
|
|
693
|
+
axes: Optional[str] = None,
|
|
694
|
+
data_type: Optional[Literal["tiff", "custom"]] = None,
|
|
695
|
+
tta_transforms: bool = False,
|
|
696
|
+
dataloader_params: Optional[dict] = None,
|
|
697
|
+
read_source_func: Optional[Callable] = None,
|
|
698
|
+
extension_filter: str = "",
|
|
699
|
+
write_type: Literal["tiff", "custom"] = "tiff",
|
|
700
|
+
write_extension: Optional[str] = None,
|
|
701
|
+
write_func: Optional[WriteFunc] = None,
|
|
702
|
+
write_func_kwargs: Optional[dict[str, Any]] = None,
|
|
703
|
+
prediction_dir: Union[Path, str] = "predictions",
|
|
704
|
+
**kwargs,
|
|
705
|
+
) -> None:
|
|
706
|
+
"""
|
|
707
|
+
Make predictions on the provided data and save outputs to files.
|
|
708
|
+
|
|
709
|
+
The predictions will be saved in a new directory 'predictions' within the set
|
|
710
|
+
working directory. The directory stucture within the 'predictions' directory
|
|
711
|
+
will match that of the source directory.
|
|
712
|
+
|
|
713
|
+
The `source` must be from files and not arrays. The file names of the
|
|
714
|
+
predictions will match those of the source. If there is more than one sample
|
|
715
|
+
within a file, the samples will be saved to seperate files. The file names of
|
|
716
|
+
samples will have the name of the corresponding source file but with the sample
|
|
717
|
+
index appended. E.g. If the the source file name is 'images.tiff' then the first
|
|
718
|
+
sample's prediction will be saved with the file name "image_0.tiff".
|
|
719
|
+
Input can be a PredictDataModule instance, a path to a data file, or a numpy
|
|
720
|
+
array.
|
|
721
|
+
|
|
722
|
+
If `data_type`, `axes` and `tile_size` are not provided, the training
|
|
723
|
+
configuration parameters will be used, with the `patch_size` instead of
|
|
724
|
+
`tile_size`.
|
|
725
|
+
|
|
726
|
+
Test-time augmentation (TTA) can be switched on using the `tta_transforms`
|
|
727
|
+
parameter. The TTA augmentation applies all possible flip and 90 degrees
|
|
728
|
+
rotations to the prediction input and averages the predictions. TTA augmentation
|
|
729
|
+
should not be used if you did not train with these augmentations.
|
|
730
|
+
|
|
731
|
+
Note that if you are using a UNet model and tiling, the tile size must be
|
|
732
|
+
divisible in every dimension by 2**d, where d is the depth of the model. This
|
|
733
|
+
avoids artefacts arising from the broken shift invariance induced by the
|
|
734
|
+
pooling layers of the UNet. If your image has less dimensions, as it may
|
|
735
|
+
happen in the Z dimension, consider padding your image.
|
|
736
|
+
|
|
737
|
+
Parameters
|
|
738
|
+
----------
|
|
739
|
+
source : PredictDataModule or pathlib.Path, str
|
|
740
|
+
Data to predict on.
|
|
741
|
+
batch_size : int, default=1
|
|
742
|
+
Batch size for prediction.
|
|
743
|
+
tile_size : tuple of int, optional
|
|
744
|
+
Size of the tiles to use for prediction.
|
|
745
|
+
tile_overlap : tuple of int, default=(48, 48)
|
|
746
|
+
Overlap between tiles.
|
|
747
|
+
axes : str, optional
|
|
748
|
+
Axes of the input data, by default None.
|
|
749
|
+
data_type : {"array", "tiff", "custom"}, optional
|
|
750
|
+
Type of the input data.
|
|
751
|
+
tta_transforms : bool, default=True
|
|
752
|
+
Whether to apply test-time augmentation.
|
|
753
|
+
dataloader_params : dict, optional
|
|
754
|
+
Parameters to pass to the dataloader.
|
|
755
|
+
read_source_func : Callable, optional
|
|
756
|
+
Function to read the source data.
|
|
757
|
+
extension_filter : str, default=""
|
|
758
|
+
Filter for the file extension.
|
|
759
|
+
write_type : {"tiff", "custom"}, default="tiff"
|
|
760
|
+
The data type to save as, includes custom.
|
|
761
|
+
write_extension : str, optional
|
|
762
|
+
If a known `write_type` is selected this argument is ignored. For a custom
|
|
763
|
+
`write_type` an extension to save the data with must be passed.
|
|
764
|
+
write_func : WriteFunc, optional
|
|
765
|
+
If a known `write_type` is selected this argument is ignored. For a custom
|
|
766
|
+
`write_type` a function to save the data must be passed. See notes below.
|
|
767
|
+
write_func_kwargs : dict of {str: any}, optional
|
|
768
|
+
Additional keyword arguments to be passed to the save function.
|
|
769
|
+
prediction_dir : Path | str, default="predictions"
|
|
770
|
+
The path to save the prediction results to. If `prediction_dir` is not
|
|
771
|
+
absolute, the directory will be assumed to be relative to the pre-set
|
|
772
|
+
`work_dir`. If the directory does not exist it will be created.
|
|
773
|
+
**kwargs : Any
|
|
774
|
+
Unused.
|
|
775
|
+
|
|
776
|
+
Raises
|
|
777
|
+
------
|
|
778
|
+
ValueError
|
|
779
|
+
If `write_type` is custom and `write_extension` is None.
|
|
780
|
+
ValueError
|
|
781
|
+
If `write_type` is custom and `write_fun is None.
|
|
782
|
+
ValueError
|
|
783
|
+
If `source` is not `str`, `Path` or `PredictDataModule`
|
|
784
|
+
"""
|
|
785
|
+
if write_func_kwargs is None:
|
|
786
|
+
write_func_kwargs = {}
|
|
787
|
+
|
|
788
|
+
if Path(prediction_dir).is_absolute():
|
|
789
|
+
write_dir = Path(prediction_dir)
|
|
790
|
+
else:
|
|
791
|
+
write_dir = self.work_dir / prediction_dir
|
|
792
|
+
write_dir.mkdir(exist_ok=True, parents=True)
|
|
793
|
+
|
|
794
|
+
# guards for custom types
|
|
795
|
+
if write_type == SupportedData.CUSTOM:
|
|
796
|
+
if write_extension is None:
|
|
797
|
+
raise ValueError(
|
|
798
|
+
"A `write_extension` must be provided for custom write types."
|
|
799
|
+
)
|
|
800
|
+
if write_func is None:
|
|
801
|
+
raise ValueError(
|
|
802
|
+
"A `write_func` must be provided for custom write types."
|
|
803
|
+
)
|
|
804
|
+
else:
|
|
805
|
+
write_func = get_write_func(write_type)
|
|
806
|
+
write_extension = SupportedData.get_extension(write_type)
|
|
807
|
+
|
|
808
|
+
# extract file names
|
|
809
|
+
source_path: Union[Path, str, NDArray]
|
|
810
|
+
source_data_type: Literal["array", "tiff", "custom"]
|
|
811
|
+
if isinstance(source, PredictDataModule):
|
|
812
|
+
source_path = source.pred_data
|
|
813
|
+
source_data_type = source.data_type
|
|
814
|
+
extension_filter = source.extension_filter
|
|
815
|
+
elif isinstance(source, (str, Path)):
|
|
816
|
+
source_path = source
|
|
817
|
+
source_data_type = data_type or self.cfg.data_config.data_type
|
|
818
|
+
extension_filter = SupportedData.get_extension_pattern(
|
|
819
|
+
SupportedData(source_data_type)
|
|
820
|
+
)
|
|
821
|
+
else:
|
|
822
|
+
raise ValueError(f"Unsupported source type: '{type(source)}'.")
|
|
823
|
+
|
|
824
|
+
if source_data_type == "array":
|
|
825
|
+
raise ValueError(
|
|
826
|
+
"Predicting to disk is not supported for input type 'array'."
|
|
827
|
+
)
|
|
828
|
+
assert isinstance(source_path, (Path, str)) # because data_type != "array"
|
|
829
|
+
source_path = Path(source_path)
|
|
830
|
+
|
|
831
|
+
file_paths = list_files(source_path, source_data_type, extension_filter)
|
|
832
|
+
|
|
833
|
+
# predict and write each file in turn
|
|
834
|
+
for file_path in file_paths:
|
|
835
|
+
# source_path is relative to original source path...
|
|
836
|
+
# should mirror original directory structure
|
|
837
|
+
prediction = self.predict(
|
|
838
|
+
source=file_path,
|
|
839
|
+
batch_size=batch_size,
|
|
840
|
+
tile_size=tile_size,
|
|
841
|
+
tile_overlap=tile_overlap,
|
|
842
|
+
axes=axes,
|
|
843
|
+
data_type=data_type,
|
|
844
|
+
tta_transforms=tta_transforms,
|
|
845
|
+
dataloader_params=dataloader_params,
|
|
846
|
+
read_source_func=read_source_func,
|
|
847
|
+
extension_filter=extension_filter,
|
|
848
|
+
**kwargs,
|
|
849
|
+
)
|
|
850
|
+
# TODO: cast to float16?
|
|
851
|
+
write_data = np.concatenate(prediction)
|
|
852
|
+
|
|
853
|
+
# create directory structure and write path
|
|
854
|
+
if not source_path.is_file():
|
|
855
|
+
file_write_dir = write_dir / file_path.parent.relative_to(source_path)
|
|
856
|
+
else:
|
|
857
|
+
file_write_dir = write_dir
|
|
858
|
+
file_write_dir.mkdir(parents=True, exist_ok=True)
|
|
859
|
+
write_path = (file_write_dir / file_path.name).with_suffix(write_extension)
|
|
860
|
+
|
|
861
|
+
# write data
|
|
862
|
+
write_func(file_path=write_path, img=write_data)
|
|
863
|
+
|
|
671
864
|
def export_to_bmz(
|
|
672
865
|
self,
|
|
673
866
|
path_to_archive: Union[Path, str],
|
|
674
867
|
friendly_model_name: str,
|
|
675
868
|
input_array: NDArray,
|
|
676
869
|
authors: list[dict],
|
|
677
|
-
general_description: str
|
|
870
|
+
general_description: str,
|
|
871
|
+
data_description: str,
|
|
872
|
+
covers: Optional[list[Union[Path, str]]] = None,
|
|
678
873
|
channel_names: Optional[list[str]] = None,
|
|
679
|
-
|
|
874
|
+
model_version: str = "0.1.0",
|
|
680
875
|
) -> None:
|
|
681
876
|
"""Export the model to the BioImage Model Zoo format.
|
|
682
877
|
|
|
@@ -706,11 +901,15 @@ class CAREamist:
|
|
|
706
901
|
authors : list of dict
|
|
707
902
|
List of authors of the model.
|
|
708
903
|
general_description : str
|
|
709
|
-
General description of the model
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
904
|
+
General description of the model used in the BMZ metadata.
|
|
905
|
+
data_description : str
|
|
906
|
+
Description of the data the model was trained on.
|
|
907
|
+
covers : list of pathlib.Path or str, default=None
|
|
908
|
+
Paths to the cover images.
|
|
909
|
+
channel_names : list of str, default=None
|
|
910
|
+
Channel names.
|
|
911
|
+
model_version : str, default="0.1.0"
|
|
912
|
+
Version of the model.
|
|
714
913
|
"""
|
|
715
914
|
# TODO: add in docs that it is expected that input_array dimensions match
|
|
716
915
|
# those in data_config
|
|
@@ -729,9 +928,21 @@ class CAREamist:
|
|
|
729
928
|
path_to_archive=path_to_archive,
|
|
730
929
|
model_name=friendly_model_name,
|
|
731
930
|
general_description=general_description,
|
|
931
|
+
data_description=data_description,
|
|
732
932
|
authors=authors,
|
|
733
933
|
input_array=input_array,
|
|
734
934
|
output_array=output,
|
|
935
|
+
covers=covers,
|
|
735
936
|
channel_names=channel_names,
|
|
736
|
-
|
|
937
|
+
model_version=model_version,
|
|
737
938
|
)
|
|
939
|
+
|
|
940
|
+
def get_losses(self) -> dict[str, list]:
|
|
941
|
+
"""Return data that can be used to plot train and validation loss curves.
|
|
942
|
+
|
|
943
|
+
Returns
|
|
944
|
+
-------
|
|
945
|
+
dict of str: list
|
|
946
|
+
Dictionary containing the losses for each epoch.
|
|
947
|
+
"""
|
|
948
|
+
return read_csv_logger(self.cfg.experiment_name, self.work_dir / "csv_logs")
|
careamics/cli/conf.py
CHANGED
|
@@ -3,12 +3,11 @@
|
|
|
3
3
|
import sys
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import
|
|
6
|
+
from typing import Annotated, Optional
|
|
7
7
|
|
|
8
8
|
import click
|
|
9
9
|
import typer
|
|
10
10
|
import yaml
|
|
11
|
-
from typing_extensions import Annotated
|
|
12
11
|
|
|
13
12
|
from ..config import (
|
|
14
13
|
Configuration,
|
|
@@ -17,6 +16,7 @@ from ..config import (
|
|
|
17
16
|
create_n2v_configuration,
|
|
18
17
|
save_configuration,
|
|
19
18
|
)
|
|
19
|
+
from .utils import handle_2D_3D_callback
|
|
20
20
|
|
|
21
21
|
WORK_DIR = Path.cwd()
|
|
22
22
|
|
|
@@ -92,26 +92,6 @@ def conf_options( # numpydoc ignore=PR01
|
|
|
92
92
|
ctx.obj = ConfOptions(dir, name, force, print)
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def patch_size_callback(value: Tuple[int, int, int]) -> Tuple[int, ...]:
|
|
96
|
-
"""
|
|
97
|
-
Callback for --patch-size option.
|
|
98
|
-
|
|
99
|
-
Parameters
|
|
100
|
-
----------
|
|
101
|
-
value : (int, int, int)
|
|
102
|
-
Patch size value.
|
|
103
|
-
|
|
104
|
-
Returns
|
|
105
|
-
-------
|
|
106
|
-
(int, int, int) | (int, int)
|
|
107
|
-
If the last element in `value` is -1 the tuple is reduced to the first two
|
|
108
|
-
values.
|
|
109
|
-
"""
|
|
110
|
-
if value[2] == -1:
|
|
111
|
-
return value[:2]
|
|
112
|
-
return value
|
|
113
|
-
|
|
114
|
-
|
|
115
95
|
# TODO: Need to decide how to parse model kwargs
|
|
116
96
|
# - Could be json style string to be loaded as dict e.g. {"depth": 3}
|
|
117
97
|
# - Cons: Annoying to type, easily have syntax errors
|
|
@@ -132,7 +112,7 @@ def care( # numpydoc ignore=PR01
|
|
|
132
112
|
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
|
|
133
113
|
),
|
|
134
114
|
click_type=click.Tuple([int, int, int]),
|
|
135
|
-
callback=
|
|
115
|
+
callback=handle_2D_3D_callback,
|
|
136
116
|
),
|
|
137
117
|
],
|
|
138
118
|
batch_size: Annotated[int, typer.Option(help="Batch size.")],
|
|
@@ -154,8 +134,12 @@ def care( # numpydoc ignore=PR01
|
|
|
154
134
|
help="Loss function to use.",
|
|
155
135
|
),
|
|
156
136
|
] = "mae",
|
|
157
|
-
n_channels_in: Annotated[
|
|
158
|
-
|
|
137
|
+
n_channels_in: Annotated[
|
|
138
|
+
Optional[int], typer.Option(help="Number of channels in")
|
|
139
|
+
] = None,
|
|
140
|
+
n_channels_out: Annotated[
|
|
141
|
+
Optional[int], typer.Option(help="Number of channels out")
|
|
142
|
+
] = None,
|
|
159
143
|
logger: Annotated[
|
|
160
144
|
click.Choice,
|
|
161
145
|
typer.Option(
|
|
@@ -215,7 +199,7 @@ def n2n( # numpydoc ignore=PR01
|
|
|
215
199
|
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
|
|
216
200
|
),
|
|
217
201
|
click_type=click.Tuple([int, int, int]),
|
|
218
|
-
callback=
|
|
202
|
+
callback=handle_2D_3D_callback,
|
|
219
203
|
),
|
|
220
204
|
],
|
|
221
205
|
batch_size: Annotated[int, typer.Option(help="Batch size.")],
|
|
@@ -237,8 +221,12 @@ def n2n( # numpydoc ignore=PR01
|
|
|
237
221
|
help="Loss function to use.",
|
|
238
222
|
),
|
|
239
223
|
] = "mae",
|
|
240
|
-
n_channels_in: Annotated[
|
|
241
|
-
|
|
224
|
+
n_channels_in: Annotated[
|
|
225
|
+
Optional[int], typer.Option(help="Number of channels in")
|
|
226
|
+
] = None,
|
|
227
|
+
n_channels_out: Annotated[
|
|
228
|
+
Optional[int], typer.Option(help="Number of channels out")
|
|
229
|
+
] = None,
|
|
242
230
|
logger: Annotated[
|
|
243
231
|
click.Choice,
|
|
244
232
|
typer.Option(
|
|
@@ -295,7 +283,7 @@ def n2v( # numpydoc ignore=PR01
|
|
|
295
283
|
"is not 3D pass the last value as -1 e.g. --patch-size 64 64 -1)."
|
|
296
284
|
),
|
|
297
285
|
click_type=click.Tuple([int, int, int]),
|
|
298
|
-
callback=
|
|
286
|
+
callback=handle_2D_3D_callback,
|
|
299
287
|
),
|
|
300
288
|
],
|
|
301
289
|
batch_size: Annotated[int, typer.Option(help="Batch size.")],
|
|
@@ -312,8 +300,8 @@ def n2v( # numpydoc ignore=PR01
|
|
|
312
300
|
] = True,
|
|
313
301
|
use_n2v2: Annotated[bool, typer.Option(help="Whether to use N2V2")] = False,
|
|
314
302
|
n_channels: Annotated[
|
|
315
|
-
int, typer.Option(help="Number of channels (in and out)")
|
|
316
|
-
] =
|
|
303
|
+
Optional[int], typer.Option(help="Number of channels (in and out)")
|
|
304
|
+
] = None,
|
|
317
305
|
roi_size: Annotated[int, typer.Option(help="N2V pixel manipulation area.")] = 11,
|
|
318
306
|
masked_pixel_percentage: Annotated[
|
|
319
307
|
float, typer.Option(help="Percentage of pixels masked in each patch.")
|
careamics/cli/main.py
CHANGED
|
@@ -7,23 +7,21 @@ its implementation is contained in the conf.py file.
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import Optional
|
|
10
|
+
from typing import Annotated, Optional
|
|
11
11
|
|
|
12
|
+
import click
|
|
12
13
|
import typer
|
|
13
|
-
from typing_extensions import Annotated
|
|
14
14
|
|
|
15
15
|
from ..careamist import CAREamist
|
|
16
16
|
from . import conf
|
|
17
|
+
from .utils import handle_2D_3D_callback
|
|
17
18
|
|
|
18
19
|
app = typer.Typer(
|
|
19
20
|
help="Run CAREamics algorithms from the command line, including Noise2Void "
|
|
20
|
-
"and its many variants and cousins"
|
|
21
|
-
|
|
22
|
-
app.add_typer(
|
|
23
|
-
conf.app,
|
|
24
|
-
name="conf",
|
|
25
|
-
# callback=conf.conf_options
|
|
21
|
+
"and its many variants and cousins",
|
|
22
|
+
pretty_exceptions_show_locals=False,
|
|
26
23
|
)
|
|
24
|
+
app.add_typer(conf.app, name="conf")
|
|
27
25
|
|
|
28
26
|
|
|
29
27
|
@app.command()
|
|
@@ -102,7 +100,7 @@ def train( # numpydoc ignore=PR01
|
|
|
102
100
|
typer.Option(
|
|
103
101
|
"--work-dir",
|
|
104
102
|
"-wd",
|
|
105
|
-
help=("Path to working directory in which to save checkpoints and
|
|
103
|
+
help=("Path to working directory in which to save checkpoints and logs"),
|
|
106
104
|
exists=True,
|
|
107
105
|
file_okay=False,
|
|
108
106
|
dir_okay=True,
|
|
@@ -123,10 +121,112 @@ def train( # numpydoc ignore=PR01
|
|
|
123
121
|
|
|
124
122
|
|
|
125
123
|
@app.command()
|
|
126
|
-
def predict(
|
|
124
|
+
def predict( # numpydoc ignore=PR01
|
|
125
|
+
model: Annotated[
|
|
126
|
+
Path,
|
|
127
|
+
typer.Argument(
|
|
128
|
+
help="Path to a configuration file or a trained model.",
|
|
129
|
+
exists=True,
|
|
130
|
+
file_okay=True,
|
|
131
|
+
dir_okay=False,
|
|
132
|
+
),
|
|
133
|
+
],
|
|
134
|
+
source: Annotated[
|
|
135
|
+
Path,
|
|
136
|
+
typer.Argument(
|
|
137
|
+
help="Path to the training data. Can be a directory or single file.",
|
|
138
|
+
exists=True,
|
|
139
|
+
file_okay=True,
|
|
140
|
+
dir_okay=True,
|
|
141
|
+
),
|
|
142
|
+
],
|
|
143
|
+
batch_size: Annotated[int, typer.Option(help="Batch size.")] = 1,
|
|
144
|
+
tile_size: Annotated[
|
|
145
|
+
Optional[click.Tuple],
|
|
146
|
+
typer.Option(
|
|
147
|
+
help=(
|
|
148
|
+
"Size of the tiles to use for prediction, (if the data "
|
|
149
|
+
"is not 3D pass the last value as -1 e.g. --tile_size 64 64 -1)."
|
|
150
|
+
),
|
|
151
|
+
click_type=click.Tuple([int, int, int]),
|
|
152
|
+
callback=handle_2D_3D_callback,
|
|
153
|
+
),
|
|
154
|
+
] = None,
|
|
155
|
+
tile_overlap: Annotated[
|
|
156
|
+
click.Tuple,
|
|
157
|
+
typer.Option(
|
|
158
|
+
help=(
|
|
159
|
+
"Overlap between tiles, (if the data is not 3D pass the last value as "
|
|
160
|
+
"-1 e.g. --tile_overlap 64 64 -1)."
|
|
161
|
+
),
|
|
162
|
+
click_type=click.Tuple([int, int, int]),
|
|
163
|
+
callback=handle_2D_3D_callback,
|
|
164
|
+
),
|
|
165
|
+
] = (48, 48, -1),
|
|
166
|
+
axes: Annotated[
|
|
167
|
+
Optional[str],
|
|
168
|
+
typer.Option(
|
|
169
|
+
help="Axes of the input data. If unused the data is assumed to have the "
|
|
170
|
+
"same axes as the original training data."
|
|
171
|
+
),
|
|
172
|
+
] = None,
|
|
173
|
+
data_type: Annotated[
|
|
174
|
+
click.Choice,
|
|
175
|
+
typer.Option(click_type=click.Choice(["tiff"]), help="Type of the input data."),
|
|
176
|
+
] = "tiff",
|
|
177
|
+
tta_transforms: Annotated[
|
|
178
|
+
bool,
|
|
179
|
+
typer.Option(
|
|
180
|
+
"--tta-transforms/--no-tta-transforms",
|
|
181
|
+
"-t/-T",
|
|
182
|
+
help="Whether to apply test-time augmentation.",
|
|
183
|
+
),
|
|
184
|
+
] = False,
|
|
185
|
+
write_type: Annotated[
|
|
186
|
+
click.Choice,
|
|
187
|
+
typer.Option(
|
|
188
|
+
click_type=click.Choice(["tiff"]), help="Type of the output data."
|
|
189
|
+
),
|
|
190
|
+
] = "tiff",
|
|
191
|
+
# TODO: could make dataloader_params as json, necessary?
|
|
192
|
+
work_dir: Annotated[
|
|
193
|
+
Optional[Path],
|
|
194
|
+
typer.Option(
|
|
195
|
+
"--work-dir",
|
|
196
|
+
"-wd",
|
|
197
|
+
help=("Path to working directory."),
|
|
198
|
+
exists=True,
|
|
199
|
+
file_okay=False,
|
|
200
|
+
dir_okay=True,
|
|
201
|
+
),
|
|
202
|
+
] = None,
|
|
203
|
+
prediction_dir: Annotated[
|
|
204
|
+
Path,
|
|
205
|
+
typer.Option(
|
|
206
|
+
"--prediction-dir",
|
|
207
|
+
"-pd",
|
|
208
|
+
help=(
|
|
209
|
+
"Directory to save predictions to. If not an abosulte path it will be "
|
|
210
|
+
"relative to the set working directory."
|
|
211
|
+
),
|
|
212
|
+
file_okay=False,
|
|
213
|
+
dir_okay=True,
|
|
214
|
+
),
|
|
215
|
+
] = Path("predictions"),
|
|
216
|
+
):
|
|
127
217
|
"""Create and save predictions from CAREamics models."""
|
|
128
|
-
|
|
129
|
-
|
|
218
|
+
engine = CAREamist(source=model, work_dir=work_dir)
|
|
219
|
+
engine.predict_to_disk(
|
|
220
|
+
source=source,
|
|
221
|
+
batch_size=batch_size,
|
|
222
|
+
tile_size=tile_size,
|
|
223
|
+
tile_overlap=tile_overlap,
|
|
224
|
+
axes=axes,
|
|
225
|
+
data_type=data_type,
|
|
226
|
+
tta_transforms=tta_transforms,
|
|
227
|
+
write_type=write_type,
|
|
228
|
+
prediction_dir=prediction_dir,
|
|
229
|
+
)
|
|
130
230
|
|
|
131
231
|
|
|
132
232
|
def run():
|