careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,529 @@
|
|
|
1
|
+
"""Next-Generation CAREamics DataModule."""
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from collections.abc import Callable, Sequence
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Literal, Union, overload
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pytorch_lightning as L
|
|
10
|
+
from numpy.typing import NDArray
|
|
11
|
+
from torch.utils.data import DataLoader, Sampler
|
|
12
|
+
from torch.utils.data._utils.collate import default_collate
|
|
13
|
+
|
|
14
|
+
from careamics.config.data.ng_data_config import NGDataConfig
|
|
15
|
+
from careamics.config.support import SupportedData
|
|
16
|
+
from careamics.dataset_ng.factory import create_dataset
|
|
17
|
+
from careamics.dataset_ng.grouped_index_sampler import GroupedIndexSampler
|
|
18
|
+
from careamics.dataset_ng.image_stack_loader import ImageStackLoader
|
|
19
|
+
from careamics.lightning.dataset_ng.data_module_utils import initialize_data_pair
|
|
20
|
+
from careamics.utils import get_logger
|
|
21
|
+
|
|
22
|
+
# Module-level logger for this file, created with careamics' `get_logger`.
logger = get_logger(__name__)
|
|
23
|
+
|
|
24
|
+
ItemType = Union[Union[Path, str], NDArray[Any]]
"""A single input item: a path (``Path`` or ``str``) or an in-memory NumPy array."""

InputType = Union[ItemType, Sequence[ItemType], None]
"""Input for one data split: a single item, a sequence of items, or ``None``."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CareamicsDataModule(L.LightningDataModule):
    """Data module for Careamics dataset.

    Parameters
    ----------
    data_config : NGDataConfig
        Pydantic model for CAREamics data configuration.
    train_data : Optional[InputType]
        Training data, can be a path to a folder, a list of paths, or a numpy array.
    train_data_target : Optional[InputType]
        Training data target, can be a path to a folder,
        a list of paths, or a numpy array.
    train_data_mask : InputType (when filtering is needed)
        Training data mask, can be a path to a folder,
        a list of paths, or a numpy array. Used for coordinate filtering.
        Only required when using coordinate-based patch filtering.
    val_data : Optional[InputType]
        Validation data, can be a path to a folder,
        a list of paths, or a numpy array.
    val_data_target : Optional[InputType]
        Validation data target, can be a path to a folder,
        a list of paths, or a numpy array.
    pred_data : Optional[InputType]
        Prediction data, can be a path to a folder, a list of paths,
        or a numpy array.
    pred_data_target : Optional[InputType]
        Prediction data target, can be a path to a folder,
        a list of paths, or a numpy array.
    read_source_func : Optional[Callable], default=None
        Function to read the source data. Only used for `custom`
        data type (see DataModel).
    read_kwargs : Optional[dict[str, Any]]
        The kwargs for the read source function.
    image_stack_loader : Optional[ImageStackLoader]
        The image stack loader.
    image_stack_loader_kwargs : Optional[dict[str, Any]]
        The image stack loader kwargs.
    extension_filter : str, default=""
        Filter for file extensions. Only used for `custom` data types
        (see DataModel).
    val_percentage : Optional[float]
        Percentage of the training data to use for validation. Only
        used if `val_data` is None. Splitting validation data from the training
        data is not implemented yet: passing a value raises `NotImplementedError`.
    val_minimum_split : int, default=5
        Minimum number of patches or files to split from the training data for
        validation. Only used if `val_data` is None.

    Attributes
    ----------
    config : NGDataConfig
        Pydantic model for CAREamics data configuration.
    data_type : str
        Type of data, one of SupportedData.
    batch_size : int
        Batch size for the dataloaders.
    extension_filter : str
        Filter for file extensions, by default "".
    read_source_func : Optional[Callable], default=None
        Function to read the source data.
    read_kwargs : Optional[dict[str, Any]], default=None
        The kwargs for the read source function.
    image_stack_loader : Optional[ImageStackLoader], default=None
        The image stack loader.
    image_stack_loader_kwargs : Optional[dict[str, Any]], default=None
        The image stack loader kwargs.
    val_percentage : Optional[float]
        Percentage of the training data to use for validation.
    val_minimum_split : int, default=5
        Minimum number of patches or files to split from the training data for
        validation.
    train_data : Optional[Any]
        Training data, can be a path to a folder, a list of paths, or a numpy array.
    train_data_target : Optional[Any]
        Training data target, can be a path to a folder, a list of paths, or a numpy
        array.
    train_data_mask : Optional[Any]
        Training data mask, can be a path to a folder, a list of paths, or a numpy
        array.
    val_data : Optional[Any]
        Validation data, can be a path to a folder, a list of paths, or a numpy array.
    val_data_target : Optional[Any]
        Validation data target, can be a path to a folder, a list of paths, or a numpy
        array.
    pred_data : Optional[Any]
        Prediction data, can be a path to a folder, a list of paths, or a numpy array.
    pred_data_target : Optional[Any]
        Prediction data target, can be a path to a folder, a list of paths, or a numpy
        array.

    Raises
    ------
    ValueError
        If none of train_data, val_data or pred_data is provided.
    NotImplementedError
        If `val_percentage` is provided (validation split is not implemented).
    ValueError
        If input and target data types are not consistent.
    """
|
|
124
|
+
|
|
125
|
+
# standard use (no mask)
# Overload: paths or arrays only; no custom read function, no image stack
# loader and no training mask.
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None: ...

# with training mask for filtering
# Overload: same as the standard case, but `train_data_mask` is required
# (keyword-only) so that coordinate-based patch filtering can be applied.
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    train_data_mask: InputType,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None: ...

# custom read function (no mask)
# Overload: a user-provided `read_source_func` (with optional `read_kwargs`)
# is required; no training mask.
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    read_source_func: Callable,
    read_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None: ...

# custom read function with training mask
# Overload: both `read_source_func` and `train_data_mask` are required.
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    train_data_mask: InputType,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    read_source_func: Callable,
    read_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None: ...

# image stack loader (no mask)
# Overload: a user-provided `image_stack_loader` is required. Data arguments
# are typed `Any` because their form is defined by the loader.
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    pred_data: Any | None = None,
    pred_data_target: Any | None = None,
    image_stack_loader: ImageStackLoader,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None: ...

# image stack loader with training mask
# Overload: both `image_stack_loader` and `train_data_mask` are required.
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    train_data_mask: Any,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    pred_data: Any | None = None,
    pred_data_target: Any | None = None,
    image_stack_loader: ImageStackLoader,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None: ...
|
|
237
|
+
|
|
238
|
+
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    train_data_mask: Any | None = None,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    pred_data: Any | None = None,
    pred_data_target: Any | None = None,
    read_source_func: Callable | None = None,
    read_kwargs: dict[str, Any] | None = None,
    image_stack_loader: ImageStackLoader | None = None,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """
    Data module for Careamics dataset initialization.

    Create a lightning datamodule that handles creating datasets for training,
    validation, and prediction.

    Parameters
    ----------
    data_config : NGDataConfig
        Pydantic model for CAREamics data configuration.
    train_data : Optional[InputType]
        Training data, can be a path to a folder, a list of paths, or a numpy array.
    train_data_target : Optional[InputType]
        Training data target, can be a path to a folder,
        a list of paths, or a numpy array.
    train_data_mask : InputType (when filtering is needed)
        Training data mask, can be a path to a folder,
        a list of paths, or a numpy array. Used for coordinate filtering.
        Only required when using coordinate-based patch filtering.
    val_data : Optional[InputType]
        Validation data, can be a path to a folder,
        a list of paths, or a numpy array.
    val_data_target : Optional[InputType]
        Validation data target, can be a path to a folder,
        a list of paths, or a numpy array.
    pred_data : Optional[InputType]
        Prediction data, can be a path to a folder, a list of paths,
        or a numpy array.
    pred_data_target : Optional[InputType]
        Prediction data target, can be a path to a folder,
        a list of paths, or a numpy array.
    read_source_func : Optional[Callable]
        Function to read the source data, by default None. Only used for `custom`
        data type (see DataModel).
    read_kwargs : Optional[dict[str, Any]]
        The kwargs for the read source function.
    image_stack_loader : Optional[ImageStackLoader]
        The image stack loader.
    image_stack_loader_kwargs : Optional[dict[str, Any]]
        The image stack loader kwargs.
    extension_filter : str
        Filter for file extensions, by default "". Only used for `custom` data types
        (see DataModel).
    val_percentage : Optional[float]
        Percentage of the training data to use for validation. Only
        used if `val_data` is None. Not implemented yet: any value other than
        None raises `NotImplementedError`.
    val_minimum_split : int
        Minimum number of patches or files to split from the training data for
        validation, by default 5. Only used if `val_data` is None.

    Raises
    ------
    ValueError
        If none of `train_data`, `val_data` or `pred_data` is provided.
    NotImplementedError
        If `val_percentage` is not None (validation split is not implemented).
    ValueError
        If exactly one of `train_data` and `val_data` is provided.
    """
    super().__init__()

    if train_data is None and val_data is None and pred_data is None:
        raise ValueError(
            "At least one of train_data, val_data or pred_data must be provided."
        )

    # TODO: implement the validation split logic
    # Checked before the train/val pairing below so that a user asking for a
    # split (train_data + val_percentage, no val_data) gets the specific
    # "not implemented" error rather than the generic pairing error.
    if val_percentage is not None:
        raise NotImplementedError("Validation split is not implemented.")

    # The parentheses are required: without them Python chains the comparison
    # as `train_data is None and None != val_data and val_data is None`, which
    # is always False, so the check would never fire.
    if (train_data is None) != (val_data is None):
        raise ValueError(
            "If one of train_data or val_data is provided, both must be provided."
        )

    self.config: NGDataConfig = data_config
    self.data_type: str = data_config.data_type
    self.batch_size: int = data_config.batch_size

    self.extension_filter: str = extension_filter  # list_files pulls the correct ext
    self.read_source_func = read_source_func
    self.read_kwargs = read_kwargs
    self.image_stack_loader = image_stack_loader
    self.image_stack_loader_kwargs = image_stack_loader_kwargs

    self.val_percentage = val_percentage
    self.val_minimum_split = val_minimum_split

    # Tells `initialize_data_pair` that inputs are meant for a user-provided
    # image stack loader (presumably so it skips the default input handling;
    # verify in data_module_utils).
    custom_loader = self.image_stack_loader is not None

    self.train_data, self.train_data_target = initialize_data_pair(
        self.data_type,
        train_data,
        train_data_target,
        extension_filter,
        custom_loader,
    )
    # The mask has no target of its own, so only the first element is kept.
    self.train_data_mask, _ = initialize_data_pair(
        self.data_type, train_data_mask, None, extension_filter, custom_loader
    )

    self.val_data, self.val_data_target = initialize_data_pair(
        self.data_type, val_data, val_data_target, extension_filter, custom_loader
    )

    # The pred_data_target can be needed to count metrics on the prediction
    self.pred_data, self.pred_data_target = initialize_data_pair(
        self.data_type, pred_data, pred_data_target, extension_filter, custom_loader
    )
|
|
356
|
+
|
|
357
|
+
def setup(self, stage: str) -> None:
|
|
358
|
+
"""
|
|
359
|
+
Setup datasets.
|
|
360
|
+
|
|
361
|
+
Lightning hook that is called at the beginning of fit (train + validate),
|
|
362
|
+
validate, test, or predict. Creates the datasets for a given stage.
|
|
363
|
+
|
|
364
|
+
Parameters
|
|
365
|
+
----------
|
|
366
|
+
stage : str
|
|
367
|
+
The stage to set up datasets for.
|
|
368
|
+
Is either 'fit', 'validate', 'test', or 'predict'.
|
|
369
|
+
|
|
370
|
+
Raises
|
|
371
|
+
------
|
|
372
|
+
NotImplementedError
|
|
373
|
+
If stage is not one of "fit", "validate" or "predict".
|
|
374
|
+
"""
|
|
375
|
+
if stage == "fit":
|
|
376
|
+
if self.config.mode != "training":
|
|
377
|
+
raise ValueError(
|
|
378
|
+
f"CAREamicsDataModule configured for {self.config.mode} cannot be "
|
|
379
|
+
f"used for training. Please create a new CareamicsDataModule with "
|
|
380
|
+
f"a configuration with mode='training'."
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
self.train_dataset = create_dataset(
|
|
384
|
+
config=self.config,
|
|
385
|
+
inputs=self.train_data,
|
|
386
|
+
targets=self.train_data_target,
|
|
387
|
+
masks=self.train_data_mask,
|
|
388
|
+
read_func=self.read_source_func,
|
|
389
|
+
read_kwargs=self.read_kwargs,
|
|
390
|
+
image_stack_loader=self.image_stack_loader,
|
|
391
|
+
image_stack_loader_kwargs=self.image_stack_loader_kwargs,
|
|
392
|
+
)
|
|
393
|
+
# TODO: ugly, need to find a better solution
|
|
394
|
+
self.stats = self.train_dataset.input_stats
|
|
395
|
+
self.config.set_means_and_stds(
|
|
396
|
+
self.train_dataset.input_stats.means,
|
|
397
|
+
self.train_dataset.input_stats.stds,
|
|
398
|
+
self.train_dataset.target_stats.means,
|
|
399
|
+
self.train_dataset.target_stats.stds,
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
validation_config = self.config.convert_mode("validating")
|
|
403
|
+
self.val_dataset = create_dataset(
|
|
404
|
+
config=validation_config,
|
|
405
|
+
inputs=self.val_data,
|
|
406
|
+
targets=self.val_data_target,
|
|
407
|
+
read_func=self.read_source_func,
|
|
408
|
+
read_kwargs=self.read_kwargs,
|
|
409
|
+
image_stack_loader=self.image_stack_loader,
|
|
410
|
+
image_stack_loader_kwargs=self.image_stack_loader_kwargs,
|
|
411
|
+
)
|
|
412
|
+
elif stage == "validate":
|
|
413
|
+
validation_config = self.config.convert_mode("validating")
|
|
414
|
+
self.val_dataset = create_dataset(
|
|
415
|
+
config=validation_config,
|
|
416
|
+
inputs=self.val_data,
|
|
417
|
+
targets=self.val_data_target,
|
|
418
|
+
read_func=self.read_source_func,
|
|
419
|
+
read_kwargs=self.read_kwargs,
|
|
420
|
+
image_stack_loader=self.image_stack_loader,
|
|
421
|
+
image_stack_loader_kwargs=self.image_stack_loader_kwargs,
|
|
422
|
+
)
|
|
423
|
+
self.stats = self.val_dataset.input_stats
|
|
424
|
+
elif stage == "predict":
|
|
425
|
+
if self.config.mode == "validating":
|
|
426
|
+
raise ValueError(
|
|
427
|
+
"CAREamicsDataModule configured for validating cannot be used for "
|
|
428
|
+
"prediction. Please create a new CareamicsDataModule with a "
|
|
429
|
+
"configuration with mode='predicting'."
|
|
430
|
+
)
|
|
431
|
+
|
|
432
|
+
self.predict_dataset = create_dataset(
|
|
433
|
+
config=(
|
|
434
|
+
self.config.convert_mode("predicting")
|
|
435
|
+
if self.config.mode == "training"
|
|
436
|
+
else self.config
|
|
437
|
+
),
|
|
438
|
+
inputs=self.pred_data,
|
|
439
|
+
targets=self.pred_data_target,
|
|
440
|
+
read_func=self.read_source_func,
|
|
441
|
+
read_kwargs=self.read_kwargs,
|
|
442
|
+
image_stack_loader=self.image_stack_loader,
|
|
443
|
+
image_stack_loader_kwargs=self.image_stack_loader_kwargs,
|
|
444
|
+
)
|
|
445
|
+
self.stats = self.predict_dataset.input_stats
|
|
446
|
+
else:
|
|
447
|
+
raise NotImplementedError(f"Stage {stage} not implemented")
|
|
448
|
+
|
|
449
|
+
def _sampler(self, dataset: Literal["train", "val", "predict"]) -> Sampler | None:
|
|
450
|
+
sampler: GroupedIndexSampler | None
|
|
451
|
+
rng = np.random.default_rng(self.config.seed)
|
|
452
|
+
if not self.config.in_memory and self.config.data_type == SupportedData.TIFF:
|
|
453
|
+
match dataset:
|
|
454
|
+
case "train":
|
|
455
|
+
ds = self.train_dataset
|
|
456
|
+
case "val":
|
|
457
|
+
ds = self.val_dataset
|
|
458
|
+
case "predict":
|
|
459
|
+
ds = self.predict_dataset
|
|
460
|
+
case _:
|
|
461
|
+
raise (
|
|
462
|
+
f"Unrecognized dataset '{dataset}', should be one of 'train', "
|
|
463
|
+
"'val' or 'predict'."
|
|
464
|
+
)
|
|
465
|
+
sampler = GroupedIndexSampler.from_dataset(ds, rng=rng)
|
|
466
|
+
else:
|
|
467
|
+
sampler = None
|
|
468
|
+
return sampler
|
|
469
|
+
|
|
470
|
+
def train_dataloader(self) -> DataLoader:
    """
    Build the dataloader that iterates over the training dataset.

    Extra options come from a deep copy of `config.train_dataloader_params`,
    so the configuration itself is never modified.

    Returns
    -------
    DataLoader
        Training dataloader.
    """
    train_sampler = self._sampler("train")
    loader_kwargs = copy.deepcopy(self.config.train_dataloader_params)
    if train_sampler is not None:
        # torch rejects `shuffle` combined with a custom sampler:
        # "ValueError: sampler option is mutually exclusive with shuffle".
        # TODO: there might be other parameters mutually exclusive with sampler
        loader_kwargs.pop("shuffle", None)

    return DataLoader(
        self.train_dataset,
        sampler=train_sampler,
        batch_size=self.batch_size,
        collate_fn=default_collate,
        **loader_kwargs,
    )
|
|
493
|
+
|
|
494
|
+
def val_dataloader(self) -> DataLoader:
    """
    Build the dataloader that iterates over the validation dataset.

    Extra options come from a deep copy of `config.val_dataloader_params`.
    `shuffle` is dropped whenever a custom sampler is in use.

    Returns
    -------
    DataLoader
        Validation dataloader.
    """
    val_sampler = self._sampler("val")
    loader_kwargs = copy.deepcopy(self.config.val_dataloader_params)
    if val_sampler is not None:
        # torch does not accept `shuffle` together with a custom sampler
        loader_kwargs.pop("shuffle", None)

    return DataLoader(
        self.val_dataset,
        sampler=val_sampler,
        batch_size=self.batch_size,
        collate_fn=default_collate,
        **loader_kwargs,
    )
|
|
514
|
+
|
|
515
|
+
def predict_dataloader(self) -> DataLoader:
    """
    Build the dataloader that iterates over the prediction dataset.

    Unlike the training and validation loaders, no custom sampler is
    requested here. `config.pred_dataloader_params` is forwarded as-is.

    Returns
    -------
    DataLoader
        Prediction dataloader.
    """
    extra_options = self.config.pred_dataloader_params
    loader = DataLoader(
        self.predict_dataset,
        batch_size=self.batch_size,
        collate_fn=default_collate,
        **extra_options,
    )
    return loader
|