careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
"""Utility functions for file and paths solver."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Literal
|
|
6
|
+
|
|
7
|
+
from numpy import ndarray
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
|
|
10
|
+
from careamics.config.support import SupportedData
|
|
11
|
+
from careamics.dataset.dataset_utils import list_files, validate_source_target_files
|
|
12
|
+
from careamics.dataset_ng.image_stack_loader.zarr_utils import is_valid_uri
|
|
13
|
+
|
|
14
|
+
# A single data item: a filesystem path (``Path`` or ``str``; for zarr data a
# ``str`` may also be a URI, see ``validate_zarr_input``) or an in-memory numpy
# array.
ItemType = Path | str | NDArray[Any]
"""Type of input items passed to the dataset."""

# Everything accepted as input/target by ``initialize_data_pair``: one item, a
# sequence of items, or ``None`` when no data is provided.
InputType = ItemType | Sequence[ItemType] | None
"""Type of input data passed to the dataset."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def list_files_in_directory(
    data_type: Literal["tiff", "zarr", "czi", "custom"],
    input_data,
    target_data=None,
    extension_filter: str = "",
) -> tuple[list[Path], list[Path] | None]:
    """Collect the files found at the input (and optional target) location.

    Parameters
    ----------
    data_type : Literal["tiff", "zarr", "czi", "custom"]
        Type of data, forwarded to ``list_files`` to select matching files.
    input_data : str or pathlib.Path
        Path to a folder, or to a single file, holding the input data.
    target_data : str or pathlib.Path, optional
        Path to a folder, or to a single file, holding the target data. When
        ``None``, no target files are listed.
    extension_filter : str, default=""
        File extension filter forwarded to ``list_files``.

    Returns
    -------
    list[pathlib.Path]
        Files found for the input data.
    list[pathlib.Path] | None
        Files found for the target data, or ``None`` when ``target_data`` is
        ``None``.
    """
    # A path to a single file with a matching extension yields a one-element
    # list.
    source_files = list_files(Path(input_data), data_type, extension_filter)

    if target_data is None:
        return source_files, None

    paired_files = list_files(Path(target_data), data_type, extension_filter)
    # Cross-check the two listings (see ``validate_source_target_files`` for
    # the exact rules).
    validate_source_target_files(source_files, paired_files)
    return source_files, paired_files
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def convert_paths_to_pathlib(
|
|
63
|
+
input_data: Sequence[str | Path],
|
|
64
|
+
target_data: Sequence[str | Path] | None = None,
|
|
65
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
66
|
+
"""Create a list of file paths from the input and target data.
|
|
67
|
+
|
|
68
|
+
Parameters
|
|
69
|
+
----------
|
|
70
|
+
input_data : Sequence[str | Path]
|
|
71
|
+
Input data, can be a path to a folder, or a list of paths.
|
|
72
|
+
target_data : Sequence[str | Path] | None
|
|
73
|
+
Target data, can be None, a path to a folder, or a list of paths.
|
|
74
|
+
|
|
75
|
+
Returns
|
|
76
|
+
-------
|
|
77
|
+
list[Path]
|
|
78
|
+
A list of file paths for input data.
|
|
79
|
+
list[Path] | None
|
|
80
|
+
A list of file paths for target data, or None if target_data is None.
|
|
81
|
+
"""
|
|
82
|
+
input_files = [Path(item) if isinstance(item, str) else item for item in input_data]
|
|
83
|
+
if target_data is None:
|
|
84
|
+
return input_files, None
|
|
85
|
+
else:
|
|
86
|
+
target_files = [
|
|
87
|
+
Path(item) if isinstance(item, str) else item for item in target_data
|
|
88
|
+
]
|
|
89
|
+
validate_source_target_files(input_files, target_files)
|
|
90
|
+
return input_files, target_files
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def validate_input_target_type_consistency(
    input_data: InputType,
    target_data: InputType | None,
) -> None:
    """Validate if the input and target data types are consistent.

    Nothing is checked unless both ``input_data`` and ``target_data`` are
    provided. ``input_data`` must be an instance of the type of ``target_data``
    (so, for instance, a ``str`` input and a ``Path`` target are rejected). For
    lists, the lengths must match and the first elements must have consistent
    types.

    Parameters
    ----------
    input_data : InputType
        Input data, can be a path to a folder, a list of paths, or a numpy array.
    target_data : Optional[InputType]
        Target data, can be None, a path to a folder, a list of paths, or a numpy
        array.

    Raises
    ------
    ValueError
        If the input and target data types are not consistent, or if input and
        target lists have different lengths.
    """
    if input_data is None or target_data is None:
        return

    if not isinstance(input_data, type(target_data)):
        raise ValueError(
            f"Inputs for input and target must be of the same type or None. "
            f"Got {type(input_data)} and {type(target_data)}."
        )

    if isinstance(input_data, list) and isinstance(target_data, list):
        if len(input_data) != len(target_data):
            raise ValueError(
                f"Inputs and targets must have the same length. "
                f"Got {len(input_data)} and {len(target_data)}."
            )
        # Only the first elements are compared. Two empty lists are consistent
        # and must not reach the indexing below, which would raise IndexError.
        if input_data and not isinstance(input_data[0], type(target_data[0])):
            raise ValueError(
                f"Inputs and targets must have the same type. "
                f"Got {type(input_data[0])} and {type(target_data[0])}."
            )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def validate_array_input(
|
|
132
|
+
input_data: NDArray | list[NDArray],
|
|
133
|
+
target_data: NDArray | list[NDArray] | None,
|
|
134
|
+
) -> tuple[list[NDArray], list[NDArray] | None]:
|
|
135
|
+
"""Validate if the input data is a numpy array.
|
|
136
|
+
|
|
137
|
+
Parameters
|
|
138
|
+
----------
|
|
139
|
+
input_data : InputType
|
|
140
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
141
|
+
target_data : Optional[InputType]
|
|
142
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
143
|
+
array.
|
|
144
|
+
|
|
145
|
+
Returns
|
|
146
|
+
-------
|
|
147
|
+
list[numpy.ndarray]
|
|
148
|
+
Validated input data.
|
|
149
|
+
list[numpy.ndarray] | None
|
|
150
|
+
Validated target data, None if the target data is None.
|
|
151
|
+
|
|
152
|
+
Raises
|
|
153
|
+
------
|
|
154
|
+
ValueError
|
|
155
|
+
If the input data is not a numpy array or a list of numpy arrays.
|
|
156
|
+
"""
|
|
157
|
+
if isinstance(input_data, ndarray):
|
|
158
|
+
input_list = [input_data]
|
|
159
|
+
|
|
160
|
+
if target_data is not None and not isinstance(target_data, ndarray):
|
|
161
|
+
raise ValueError(
|
|
162
|
+
f"Wrong target type. Expected numpy.ndarray, got {type(target_data)}. "
|
|
163
|
+
f"Check the data_type parameter or your inputs."
|
|
164
|
+
)
|
|
165
|
+
target_list = [target_data] if target_data is not None else None
|
|
166
|
+
return input_list, target_list
|
|
167
|
+
elif isinstance(input_data, list):
|
|
168
|
+
# TODO warn if wrong types inside list
|
|
169
|
+
input_list = [array for array in input_data if isinstance(array, ndarray)]
|
|
170
|
+
|
|
171
|
+
if target_data is None:
|
|
172
|
+
target_list = None
|
|
173
|
+
else:
|
|
174
|
+
assert isinstance(target_data, list)
|
|
175
|
+
target_list = [array for array in target_data if isinstance(array, ndarray)]
|
|
176
|
+
return input_list, target_list
|
|
177
|
+
else:
|
|
178
|
+
raise ValueError(
|
|
179
|
+
f"Wrong input type. Expected numpy.ndarray or list of numpy.ndarray, got "
|
|
180
|
+
f"{type(input_data)}. Check the data_type parameter or your inputs."
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def validate_path_input(
|
|
185
|
+
data_type: Literal["tiff", "zarr", "czi", "custom"],
|
|
186
|
+
input_data: str | Path | list[str | Path],
|
|
187
|
+
target_data: str | Path | list[str | Path] | None,
|
|
188
|
+
extension_filter: str = "",
|
|
189
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
190
|
+
"""Validate if the input data is a path or a list of paths.
|
|
191
|
+
|
|
192
|
+
Parameters
|
|
193
|
+
----------
|
|
194
|
+
data_type : Literal["tiff", "zarr", "czi", "custom"]
|
|
195
|
+
The type of data to validate.
|
|
196
|
+
input_data : str | Path | list[str | Path]
|
|
197
|
+
Input data, can be a path to a folder, a list of paths, or a numpy array.
|
|
198
|
+
target_data : str | Path | list[str | Path] | None
|
|
199
|
+
Target data, can be None, a path to a folder, a list of paths, or a numpy
|
|
200
|
+
array.
|
|
201
|
+
extension_filter : str, default=""
|
|
202
|
+
File extension filter to apply when listing files.
|
|
203
|
+
|
|
204
|
+
Returns
|
|
205
|
+
-------
|
|
206
|
+
list[Path]
|
|
207
|
+
A list of file paths for input data.
|
|
208
|
+
list[Path] | None
|
|
209
|
+
A list of file paths for target data, or None if target_data is None.
|
|
210
|
+
|
|
211
|
+
Raises
|
|
212
|
+
------
|
|
213
|
+
ValueError
|
|
214
|
+
If the input data is not a path or a list of paths.
|
|
215
|
+
"""
|
|
216
|
+
if isinstance(input_data, (str, Path)):
|
|
217
|
+
input_list, target_list = list_files_in_directory(
|
|
218
|
+
data_type, input_data, target_data, extension_filter
|
|
219
|
+
)
|
|
220
|
+
return input_list, target_list
|
|
221
|
+
elif isinstance(input_data, list):
|
|
222
|
+
# TODO warn if wrong types inside list
|
|
223
|
+
input_list = [
|
|
224
|
+
Path(item)
|
|
225
|
+
for item in input_data
|
|
226
|
+
if isinstance(item, (str, Path)) and Path(item).exists()
|
|
227
|
+
]
|
|
228
|
+
|
|
229
|
+
target_list = None
|
|
230
|
+
if target_data is not None:
|
|
231
|
+
assert isinstance(target_data, list)
|
|
232
|
+
target_list = [
|
|
233
|
+
Path(item)
|
|
234
|
+
for item in target_data
|
|
235
|
+
if isinstance(item, (str, Path)) and Path(item).exists()
|
|
236
|
+
] # consistency with input is enforced by convert_paths_to_pathlib
|
|
237
|
+
|
|
238
|
+
return convert_paths_to_pathlib(input_list, target_list)
|
|
239
|
+
else:
|
|
240
|
+
raise ValueError(
|
|
241
|
+
f"Wrong input type, expected str or Path or list[str | Path], got "
|
|
242
|
+
f"{type(input_data)}. Check the data_type parameter or your inputs."
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def validate_zarr_input(
|
|
247
|
+
input_data: str | Path | list[str | Path],
|
|
248
|
+
target_data: str | Path | list[str | Path] | None,
|
|
249
|
+
) -> tuple[list[str] | list[Path], list[str] | list[Path] | None]:
|
|
250
|
+
"""Validate if the input data corresponds a zarr input.
|
|
251
|
+
|
|
252
|
+
Parameters
|
|
253
|
+
----------
|
|
254
|
+
input_data : str | Path | list[str | Path]
|
|
255
|
+
Input data, can be a path to a folder, to zarr file, a URI pointing to a zarr
|
|
256
|
+
dataset, or a list.
|
|
257
|
+
target_data : str | Path | list[str | Path] | None
|
|
258
|
+
Target data, can be None.
|
|
259
|
+
|
|
260
|
+
Returns
|
|
261
|
+
-------
|
|
262
|
+
list[str] or list[Path]
|
|
263
|
+
A list of zarr URIs or path for input data.
|
|
264
|
+
list[str] or list[Path] | None
|
|
265
|
+
A list of zarr URIs or paths for target data, or None if target_data is None.
|
|
266
|
+
|
|
267
|
+
Raises
|
|
268
|
+
------
|
|
269
|
+
ValueError
|
|
270
|
+
If the input and target data types are not consistent.
|
|
271
|
+
ValueError
|
|
272
|
+
If the input data is not a zarr URI or path, or a list of zarr URIs or paths.
|
|
273
|
+
"""
|
|
274
|
+
# validate_input_target_type_consistency is called beforehand, ensuring the types
|
|
275
|
+
# of input and target are the same
|
|
276
|
+
if isinstance(input_data, (str, Path)):
|
|
277
|
+
if Path(input_data).exists():
|
|
278
|
+
# either a path to a folder or a zarr file
|
|
279
|
+
# path to a folder will trigger collection of all zarr files in that folder
|
|
280
|
+
assert target_data is None or isinstance(target_data, (str, Path))
|
|
281
|
+
if target_data is not None and not Path(target_data).exists():
|
|
282
|
+
raise ValueError(
|
|
283
|
+
f"Target provided as path, but does not exist: {target_data}."
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
return validate_path_input("zarr", input_data, target_data)
|
|
287
|
+
elif isinstance(input_data, str) and is_valid_uri(input_data):
|
|
288
|
+
input_list = [input_data]
|
|
289
|
+
|
|
290
|
+
assert target_data is None or isinstance(target_data, str)
|
|
291
|
+
if target_data is not None and not is_valid_uri(target_data):
|
|
292
|
+
raise ValueError(
|
|
293
|
+
f"Wrong target type for zarr data. Expected a zarr URI, got "
|
|
294
|
+
f"{type(target_data)}."
|
|
295
|
+
)
|
|
296
|
+
target_list = [target_data] if target_data is not None else None
|
|
297
|
+
return input_list, target_list
|
|
298
|
+
else:
|
|
299
|
+
raise ValueError(
|
|
300
|
+
f"Wrong input type for zarr data. Expected a file URI or a path to a "
|
|
301
|
+
f" file, got {input_data}. Path may not exist."
|
|
302
|
+
)
|
|
303
|
+
elif isinstance(input_data, list):
|
|
304
|
+
# use first element as determinant of type
|
|
305
|
+
if isinstance(input_data[0], (str, Path)):
|
|
306
|
+
if Path(input_data[0]).exists():
|
|
307
|
+
return validate_path_input("zarr", input_data, target_data)
|
|
308
|
+
else:
|
|
309
|
+
final_input_list = [
|
|
310
|
+
str(item) for item in input_data if is_valid_uri(item)
|
|
311
|
+
]
|
|
312
|
+
if target_data is not None:
|
|
313
|
+
assert isinstance(target_data, list)
|
|
314
|
+
final_target_list = [
|
|
315
|
+
str(item) for item in target_data if is_valid_uri(item)
|
|
316
|
+
]
|
|
317
|
+
else:
|
|
318
|
+
final_target_list = None
|
|
319
|
+
return final_input_list, final_target_list
|
|
320
|
+
else:
|
|
321
|
+
raise ValueError(
|
|
322
|
+
f"Wrong input type for zarr data. Expected a list of file URIs or "
|
|
323
|
+
f" paths to files, got {type(input_data[0])}."
|
|
324
|
+
)
|
|
325
|
+
else:
|
|
326
|
+
raise ValueError(
|
|
327
|
+
f"Wrong input type for zarr data. Expected a file URI, a path to a file, "
|
|
328
|
+
f" or a list of those, got {type(input_data)}."
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def initialize_data_pair(
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
    input_data: InputType,
    target_data: InputType | None = None,
    extension_filter: str = "",
    custom_loader: bool = False,
) -> tuple[InputType | list[InputType], InputType | list[InputType] | None]:
    """
    Validate and normalise a pair of input and target data.

    The data is first checked for input/target type consistency, then routed to
    the validator matching ``data_type``.

    Parameters
    ----------
    data_type : Literal["array", "tiff", "zarr", "czi", "custom"]
        The type of data to initialize.
    input_data : InputType
        Input data, can be None, a path to a folder, a list of paths, or a numpy
        array. When None, ``(None, None)`` is returned and ``target_data`` is
        ignored.
    target_data : InputType | None
        Target data, can be None, a path to a folder, a list of paths, or a numpy
        array.
    extension_filter : str, default=""
        File extension filter applied when listing files; only used for
        ``"custom"`` data.
    custom_loader : bool, default=False
        Whether a custom image stack loader is used. For ``"custom"`` data, the
        inputs are then returned unchanged.

    Returns
    -------
    list[numpy.ndarray] | list[pathlib.Path]
        Initialized input data. For file paths, returns a list of Path objects. For
        numpy arrays, returns the arrays directly.
    list[numpy.ndarray] | list[pathlib.Path] | None
        Initialized target data. For file paths, returns a list of Path objects. For
        numpy arrays, returns the arrays directly. Returns None if target_data is None.

    Raises
    ------
    ValueError
        If the inputs are inconsistent or of an unsupported type.
    NotImplementedError
        If ``data_type`` is not supported.
    """
    if input_data is None:
        return None, None

    validate_input_target_type_consistency(input_data, target_data)

    if data_type == SupportedData.ARRAY:
        return validate_array_input(input_data, target_data)

    if data_type in (SupportedData.TIFF, SupportedData.CZI):
        assert data_type != SupportedData.ARRAY.value  # for mypy
        if isinstance(input_data, (str, Path)):
            assert target_data is None or isinstance(target_data, (str, Path))
        elif isinstance(input_data, list):
            assert target_data is None or isinstance(target_data, list)
        else:
            raise ValueError(
                f"Unsupported input type for {data_type}: {type(input_data)}"
            )
        return validate_path_input(data_type, input_data, target_data)

    if data_type == SupportedData.ZARR:
        return validate_zarr_input(input_data, target_data)

    if data_type == SupportedData.CUSTOM:
        # A custom image stack loader consumes the data untouched.
        if custom_loader:
            return input_data, target_data
        return validate_path_input(
            data_type, input_data, target_data, extension_filter
        )

    raise NotImplementedError(f"Unsupported data type: {data_type}")
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""CARE Lightning DataModule."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, Union
|
|
5
|
+
|
|
6
|
+
from careamics.config.algorithms.care_algorithm_config import CAREAlgorithm
|
|
7
|
+
from careamics.config.algorithms.n2n_algorithm_config import N2NAlgorithm
|
|
8
|
+
from careamics.config.support import SupportedLoss
|
|
9
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
10
|
+
from careamics.losses import mae_loss, mse_loss
|
|
11
|
+
from careamics.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
from .unet_module import UnetModule
|
|
14
|
+
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CAREModule(UnetModule):
    """CAREamics PyTorch Lightning module for the CARE algorithm.

    The module also accepts Noise2Noise (N2N) configurations, which are trained
    with the same input/target scheme.

    Parameters
    ----------
    algorithm_config : CAREAlgorithm, N2NAlgorithm or dict
        Configuration for the CARE (or N2N) algorithm, either as a configuration
        instance or a dictionary.
    """

    def __init__(
        self, algorithm_config: Union[CAREAlgorithm, N2NAlgorithm, dict]
    ) -> None:
        """Instantiate the CARE Lightning module.

        Parameters
        ----------
        algorithm_config : CAREAlgorithm, N2NAlgorithm or dict
            Configuration for the CARE (or N2N) algorithm, either as a
            configuration instance or a dictionary. A dictionary is validated as
            an `N2NAlgorithm` when its `"algorithm"` entry is `"n2n"`, and as a
            `CAREAlgorithm` otherwise.

        Raises
        ------
        TypeError
            If `algorithm_config` is not a CAREAlgorithm or an N2NAlgorithm.
        ValueError
            If the configured loss is neither MAE nor MSE.
        """
        super().__init__(algorithm_config)

        # The parent receives the raw argument; a dictionary must also be
        # turned into a configuration object here, otherwise the type check
        # and the `.loss` access below can never succeed for dict input.
        if isinstance(algorithm_config, dict):
            if algorithm_config.get("algorithm") == "n2n":
                algorithm_config = N2NAlgorithm(**algorithm_config)
            else:
                algorithm_config = CAREAlgorithm(**algorithm_config)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(algorithm_config, (CAREAlgorithm, N2NAlgorithm)):
            raise TypeError(
                "algorithm_config must be a CAREAlgorithm or a N2NAlgorithm"
            )

        # Select the pixel-wise loss requested by the configuration.
        loss = algorithm_config.loss
        if loss == SupportedLoss.MAE:
            self.loss_func: Callable = mae_loss
        elif loss == SupportedLoss.MSE:
            self.loss_func = mse_loss
        else:
            raise ValueError(f"Unsupported loss for Care: {loss}")

    def training_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> Any:
        """Training step for CARE module.

        Parameters
        ----------
        batch : (ImageRegionData, ImageRegionData)
            A tuple containing the input data and the target data.
        batch_idx : Any
            The index of the current batch in the training loop.

        Returns
        -------
        Any
            The loss value computed for the current batch.
        """
        # TODO: add validation to determine if target is initialized
        x, target = batch[0], batch[1]

        prediction = self.model(x.data)
        loss = self.loss_func(prediction, target.data)

        # The first dimension of the input tensor is used as the batch size.
        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> None:
        """Validation step for CARE module.

        Computes the validation loss, updates the metrics with the prediction
        and target, and logs the loss.

        Parameters
        ----------
        batch : (ImageRegionData, ImageRegionData)
            A tuple containing the input data and the target data.
        batch_idx : Any
            The index of the current batch in the validation loop.
        """
        x, target = batch[0], batch[1]

        prediction = self.model(x.data)
        val_loss = self.loss_func(prediction, target.data)
        self.metrics(prediction, target.data)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Noise2Void Lightning DataModule."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Union
|
|
4
|
+
|
|
5
|
+
from careamics.config import (
|
|
6
|
+
N2VAlgorithm,
|
|
7
|
+
)
|
|
8
|
+
from careamics.dataset_ng.dataset import ImageRegionData
|
|
9
|
+
from careamics.losses import n2v_loss
|
|
10
|
+
from careamics.transforms import N2VManipulateTorch
|
|
11
|
+
from careamics.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
from .unet_module import UnetModule
|
|
14
|
+
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class N2VModule(UnetModule):
    """CAREamics PyTorch Lightning module for the Noise2Void (N2V) algorithm.

    Parameters
    ----------
    algorithm_config : N2VAlgorithm or dict
        Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
        dictionary.
    """

    def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
        """Instantiate the N2V Lightning module.

        Parameters
        ----------
        algorithm_config : N2VAlgorithm or dict
            Configuration for the N2V algorithm, either as an N2VAlgorithm
            instance or a dictionary (validated as an N2VAlgorithm).

        Raises
        ------
        TypeError
            If `algorithm_config` is not an N2VAlgorithm.
        """
        super().__init__(algorithm_config)

        # The parent receives the raw argument; a dictionary must also be
        # turned into a configuration object here, otherwise the type check
        # and the `.n2v_config` access below can never succeed for dict input.
        if isinstance(algorithm_config, dict):
            algorithm_config = N2VAlgorithm(**algorithm_config)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(algorithm_config, N2VAlgorithm):
            raise TypeError("algorithm_config must be a N2VAlgorithm")

        # Masking transform applied to every batch in the training and
        # validation steps.
        self.n2v_manipulate = N2VManipulateTorch(
            n2v_manipulate_config=algorithm_config.n2v_config
        )
        self.loss_func = n2v_loss

    def _load_best_checkpoint(self) -> None:
        """Load the best checkpoint for N2V model, warning about its relevance."""
        logger.warning(
            "Loading best checkpoint for N2V model. Note that for N2V, "
            "the checkpoint with the best validation metrics may not necessarily "
            "have the best denoising performance."
        )
        super()._load_best_checkpoint()

    def training_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> Any:
        """Training step for N2V model.

        Parameters
        ----------
        batch : (ImageRegionData,) or (ImageRegionData, ImageRegionData)
            A tuple whose first element is the input data. A second (target)
            element, if present, is ignored: N2V is self-supervised.
        batch_idx : Any
            The index of the current batch in the training loop.

        Returns
        -------
        Any
            The loss value for the current training step.
        """
        x = batch[0]
        # Mask the input; the loss compares the prediction to the original
        # values using the returned mask.
        x_masked, x_original, mask = self.n2v_manipulate(x.data)
        prediction = self.model(x_masked)
        loss = self.loss_func(prediction, x_original, mask)

        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> None:
        """Validation step for N2V model.

        Parameters
        ----------
        batch : (ImageRegionData,) or (ImageRegionData, ImageRegionData)
            A tuple whose first element is the input data. A second (target)
            element, if present, is ignored.
        batch_idx : Any
            The index of the current batch in the validation loop.
        """
        x = batch[0]

        x_masked, x_original, mask = self.n2v_manipulate(x.data)
        prediction = self.model(x_masked)

        val_loss = self.loss_func(prediction, x_original, mask)
        # Metrics are computed against the unmasked input, as no clean target
        # is used by N2V.
        self.metrics(prediction, x_original)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
|