careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/lightning/microsplit_data_module.py
@@ -0,0 +1,632 @@
"""MicroSplit data module for training and validation."""

from collections.abc import Callable
from pathlib import Path
from typing import Union

import numpy as np
import pytorch_lightning as L
import tifffile
from numpy.typing import NDArray
from torch.utils.data import DataLoader

from careamics.dataset.dataset_utils.dataset_utils import reshape_array
from careamics.lvae_training.dataset import (
    DataSplitType,
    DataType,
    LCMultiChDloader,
    MicroSplitDataConfig,
)
from careamics.lvae_training.dataset.types import TilingMode


# TODO refactor
def load_one_file(fpath):
    """Load a single 2D image file.

    Parameters
    ----------
    fpath : str or Path
        Path to the image file.

    Returns
    -------
    numpy.ndarray
        Reshaped image data.
    """
    data = tifffile.imread(fpath)
    if len(data.shape) == 2:
        axes = "YX"
    elif len(data.shape) == 3:
        axes = "SYX"
    elif len(data.shape) == 4:
        axes = "STYX"
    else:
        raise ValueError(f"Invalid data shape: {data.shape}")
    data = reshape_array(data, axes)
    data = data.reshape(-1, data.shape[-2], data.shape[-1])
    return data


# TODO refactor
def load_data(datadir):
    """Load data from a directory containing channel subdirectories with image files.

    Parameters
    ----------
    datadir : str or Path
        Path to the data directory containing channel subdirectories.

    Returns
    -------
    numpy.ndarray
        Stacked array of all channels' data.
    """
    data_path = Path(datadir)

    channel_dirs = sorted(p for p in data_path.iterdir() if p.is_dir())
    channels_data = []

    for channel_dir in channel_dirs:
        image_files = sorted(f for f in channel_dir.iterdir() if f.is_file())
        channel_images = [load_one_file(image_path) for image_path in image_files]

        channel_stack = np.concatenate(
            channel_images, axis=0
        )  # FIXME: this line works if images have a singleton channel dimension.
        # Specify in the notebook or change with `torch.stack`??
        channels_data.append(channel_stack)

    final_data = np.stack(channels_data, axis=-1)
    return final_data


# TODO refactor
def get_datasplit_tuples(val_fraction, test_fraction, data_length):
    """Get train/val/test indices for data splitting.

    Parameters
    ----------
    val_fraction : float or None
        Fraction of data to use for validation.
    test_fraction : float or None
        Fraction of data to use for testing.
    data_length : int
        Total length of the dataset.

    Returns
    -------
    tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
        Training, validation, and test indices.
    """
    indices = np.arange(data_length)
    np.random.shuffle(indices)

    if val_fraction is None:
        val_fraction = 0.0
    if test_fraction is None:
        test_fraction = 0.0

    val_size = int(data_length * val_fraction)
    test_size = int(data_length * test_fraction)
    train_size = data_length - val_size - test_size

    train_idx = indices[:train_size]
    val_idx = indices[train_size : train_size + val_size]
    test_idx = indices[train_size + val_size :]

    return train_idx, val_idx, test_idx


# TODO refactor
def get_train_val_data(
    data_config,
    datadir,
    datasplit_type: DataSplitType,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
    **kwargs,
):
    """Load and split data according to configuration.

    Parameters
    ----------
    data_config : MicroSplitDataConfig
        Data configuration object.
    datadir : str or Path
        Path to the data directory.
    datasplit_type : DataSplitType
        Type of data split to return.
    val_fraction : float, optional
        Fraction of data to use for validation.
    test_fraction : float, optional
        Fraction of data to use for testing.
    allow_generation : bool, optional
        Whether to allow data generation.
    **kwargs
        Additional keyword arguments.

    Returns
    -------
    numpy.ndarray
        Split data array.
    """
    data = load_data(datadir)
    train_idx, val_idx, test_idx = get_datasplit_tuples(
        val_fraction, test_fraction, len(data)
    )

    if datasplit_type == DataSplitType.All:
        data = data.astype(np.float64)
    elif datasplit_type == DataSplitType.Train:
        data = data[train_idx].astype(np.float64)
    elif datasplit_type == DataSplitType.Val:
        data = data[val_idx].astype(np.float64)
    elif datasplit_type == DataSplitType.Test:
        # TODO this is only used for prediction, and only because old dataset uses it
        data = data[test_idx].astype(np.float64)
    else:
        raise ValueError(f"Invalid datasplit type: {datasplit_type}")

    return data


class MicroSplitDataModule(L.LightningDataModule):
    """Lightning DataModule for MicroSplit-style datasets.

    Matches the interface of TrainDataModule, but internally uses original MicroSplit
    dataset logic.

    Parameters
    ----------
    data_config : MicroSplitDataConfig
        Configuration for the MicroSplit dataset.
    train_data : str
        Path to training data directory.
    val_data : str, optional
        Path to validation data directory.
    train_data_target : str, optional
        Path to training target data.
    val_data_target : str, optional
        Path to validation target data.
    read_source_func : Callable, optional
        Function to read source data.
    extension_filter : str, optional
        File extension filter.
    val_percentage : float, optional
        Percentage of data to use for validation, by default 0.1.
    val_minimum_split : int, optional
        Minimum number of samples for validation split, by default 5.
    use_in_memory : bool, optional
        Whether to use in-memory dataset, by default True.
    """

    def __init__(
        self,
        data_config: MicroSplitDataConfig,
        train_data: str,
        val_data: str | None = None,
        train_data_target: str | None = None,
        val_data_target: str | None = None,
        read_source_func: Callable | None = None,
        extension_filter: str = "",
        val_percentage: float = 0.1,
        val_minimum_split: int = 5,
        use_in_memory: bool = True,
    ):
        """Initialize MicroSplitDataModule.

        Parameters
        ----------
        data_config : MicroSplitDataConfig
            Configuration for the MicroSplit dataset.
        train_data : str
            Path to training data directory.
        val_data : str, optional
            Path to validation data directory.
        train_data_target : str, optional
            Path to training target data.
        val_data_target : str, optional
            Path to validation target data.
        read_source_func : Callable, optional
            Function to read source data.
        extension_filter : str, optional
            File extension filter.
        val_percentage : float, optional
            Percentage of data to use for validation, by default 0.1.
        val_minimum_split : int, optional
            Minimum number of samples for validation split, by default 5.
        use_in_memory : bool, optional
            Whether to use in-memory dataset, by default True.
        """
        super().__init__()
        # Dataset selection logic (adapted from create_train_val_datasets)
        self.train_config = data_config  # Should configs be separated?
        self.val_config = data_config
        self.test_config = data_config

        datapath = train_data
        load_data_func = read_source_func

        dataset_class = LCMultiChDloader  # TODO hardcoded for now

        # Create datasets
        self.train_dataset = dataset_class(
            self.train_config,
            datapath,
            load_data_fn=load_data_func,
            val_fraction=val_percentage,
            test_fraction=0.1,
        )
        max_val = self.train_dataset.get_max_val()
        self.val_config.max_val = max_val
        if self.train_config.datasplit_type == DataSplitType.All:
            self.val_config.datasplit_type = DataSplitType.All
            self.test_config.datasplit_type = DataSplitType.All
        self.val_dataset = dataset_class(
            self.val_config,
            datapath,
            load_data_fn=load_data_func,
            val_fraction=val_percentage,
            test_fraction=0.1,
        )
        self.test_config.max_val = max_val
        self.test_dataset = dataset_class(
            self.test_config,
            datapath,
            load_data_fn=load_data_func,
            val_fraction=val_percentage,
            test_fraction=0.1,
        )
        mean_val, std_val = self.train_dataset.compute_mean_std()
        self.train_dataset.set_mean_std(mean_val, std_val)
        self.val_dataset.set_mean_std(mean_val, std_val)
        self.test_dataset.set_mean_std(mean_val, std_val)
        data_stats = self.train_dataset.get_mean_std()

        # Store data statistics
        self.data_stats = (
            data_stats[0],
            data_stats[1],
        )  # TODO repeats old logic, revisit

    def train_dataloader(self):
        """Create a dataloader for training.

        Returns
        -------
        DataLoader
            Training dataloader.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_config.batch_size,
            # TODO should be inside dataloader params?
            **self.train_config.train_dataloader_params,
        )

    def val_dataloader(self):
        """Create a dataloader for validation.

        Returns
        -------
        DataLoader
            Validation dataloader.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.train_config.batch_size,
            **self.val_config.val_dataloader_params,  # TODO duplicated
        )

    def get_data_stats(self):
        """Get data statistics.

        Returns
        -------
        tuple
            A tuple ``(data_stats, max_val)``, where ``data_stats`` is the
            ``(mean, std)`` pair computed on the training data and ``max_val``
            is the maximum value of the training dataset.
        """
        return self.data_stats, self.val_config.max_val  # TODO should be in the config?


def create_microsplit_train_datamodule(
    train_data: str,
    patch_size: tuple,
    data_type: DataType,
    axes: str,  # TODO should be there after refactoring
    batch_size: int,
    val_data: str | None = None,
    num_channels: int = 2,
    depth3D: int = 1,
    grid_size: tuple | None = None,
    multiscale_count: int | None = None,
    tiling_mode: TilingMode = TilingMode.ShiftBoundary,
    read_source_func: Callable | None = None,  # TODO should be there after refactoring
    extension_filter: str = "",
    val_percentage: float = 0.1,
    val_minimum_split: int = 5,
    use_in_memory: bool = True,
    transforms: list | None = None,  # TODO should it be here?
    train_dataloader_params: dict | None = None,
    val_dataloader_params: dict | None = None,
    **dataset_kwargs,
) -> MicroSplitDataModule:
    """
    Create a MicroSplitDataModule for MicroSplit-style datasets.

    Parameters
    ----------
    train_data : str
        Path to training data.
    patch_size : tuple
        Size of one patch of data.
    data_type : DataType
        Type of the dataset (must be a DataType enum value).
    axes : str
        Axes of the data (e.g., 'SYX').
    batch_size : int
        Batch size for dataloaders.
    val_data : str, optional
        Path to validation data.
    num_channels : int, default=2
        Number of channels in the input.
    depth3D : int, default=1
        Number of slices in 3D.
    grid_size : tuple, optional
        Grid size for patch extraction.
    multiscale_count : int, optional
        Number of LC scales.
    tiling_mode : TilingMode, default=ShiftBoundary
        Tiling mode for patch extraction.
    read_source_func : Callable, optional
        Function to read the source data.
    extension_filter : str, optional
        File extension filter.
    val_percentage : float, default=0.1
        Percentage of training data to use for validation.
    val_minimum_split : int, default=5
        Minimum number of patches/files for validation split.
    use_in_memory : bool, default=True
        Use in-memory dataset if possible.
    transforms : list, optional
        List of transforms to apply.
    train_dataloader_params : dict, optional
        Parameters for training dataloader.
    val_dataloader_params : dict, optional
        Parameters for validation dataloader.
    **dataset_kwargs :
        Additional arguments passed to DatasetConfig.

    Returns
    -------
    MicroSplitDataModule
        Configured MicroSplitDataModule instance.
    """
    # Create dataset configs with only valid parameters
    dataset_config_params = {
        "data_type": data_type,
        "image_size": patch_size,
        "num_channels": num_channels,
        "depth3D": depth3D,
        "grid_size": grid_size,
        "multiscale_lowres_count": multiscale_count,
        "tiling_mode": tiling_mode,
        "batch_size": batch_size,
        "train_dataloader_params": train_dataloader_params,
        "val_dataloader_params": val_dataloader_params,
        **dataset_kwargs,
    }

    train_config = MicroSplitDataConfig(
        **dataset_config_params,
        datasplit_type=DataSplitType.Train,
    )
    # val_config = MicroSplitDataConfig(
    #     **dataset_config_params,
    #     datasplit_type=DataSplitType.Val,
    # )
    # TODO, data config is duplicated here and in configuration

    return MicroSplitDataModule(
        data_config=train_config,
        train_data=train_data,
        val_data=val_data or train_data,
        train_data_target=None,
        val_data_target=None,
        read_source_func=get_train_val_data,  # Use our wrapped function
        extension_filter=extension_filter,
        val_percentage=val_percentage,
        val_minimum_split=val_minimum_split,
        use_in_memory=use_in_memory,
    )


class MicroSplitPredictDataModule(L.LightningDataModule):
    """Lightning DataModule for MicroSplit-style prediction datasets.

    Matches the interface of PredictDataModule, but internally uses MicroSplit
    dataset logic for prediction.

    Parameters
    ----------
    pred_config : MicroSplitDataConfig
        Configuration for MicroSplit prediction.
    pred_data : str or Path or numpy.ndarray
        Prediction data, can be a path to a folder, a file or a numpy array.
    read_source_func : Callable, optional
        Function to read custom types.
    extension_filter : str, optional
        Filter to filter file extensions for custom types.
    dataloader_params : dict, optional
        Dataloader parameters.
    """

    def __init__(
        self,
        pred_config: MicroSplitDataConfig,
        pred_data: Union[str, Path, NDArray],
        read_source_func: Callable | None = None,
        extension_filter: str = "",
        dataloader_params: dict | None = None,
    ) -> None:
        """
        Constructor for the MicroSplit prediction data module.

        Parameters
        ----------
        pred_config : MicroSplitDataConfig
            Configuration for MicroSplit prediction.
        pred_data : str or Path or numpy.ndarray
            Prediction data, can be a path to a folder, a file or a numpy array.
        read_source_func : Callable, optional
            Function to read custom types, by default None.
        extension_filter : str, optional
            Filter to filter file extensions for custom types, by default "".
        dataloader_params : dict, optional
            Dataloader parameters, by default {}.
        """
        super().__init__()

        if dataloader_params is None:
            dataloader_params = {}
        self.pred_config = pred_config
        self.pred_data = pred_data
        self.read_source_func = read_source_func or get_train_val_data
        self.extension_filter = extension_filter
        self.dataloader_params = dataloader_params

    def prepare_data(self) -> None:
        """Hook used to prepare the data before calling `setup`."""
        # TODO currently data preparation is handled in dataset creation, revisit!
        pass

    def setup(self, stage: str | None = None) -> None:
        """
        Hook called at the beginning of predict.

        Parameters
        ----------
        stage : str, optional
            Stage, by default None.
        """
        # Create prediction dataset using LCMultiChDloader
        self.predict_dataset = LCMultiChDloader(
            self.pred_config,
            self.pred_data,
            load_data_fn=self.read_source_func,
            val_fraction=0.0,  # no validation split for prediction
            test_fraction=1.0,  # all data goes to the test split used for prediction
        )
        self.predict_dataset.set_mean_std(*self.pred_config.data_stats)

    def predict_dataloader(self) -> DataLoader:
        """
        Create a dataloader for prediction.

        Returns
        -------
        DataLoader
            Prediction dataloader.
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=self.pred_config.batch_size,
            **self.dataloader_params,
        )


def create_microsplit_predict_datamodule(
    pred_data: Union[str, Path, NDArray],
    tile_size: tuple,
    data_type: DataType,
    axes: str,
    batch_size: int = 1,
    num_channels: int = 2,
    depth3D: int = 1,
    grid_size: int | None = None,
    multiscale_count: int | None = None,
    data_stats: tuple | None = None,
    tiling_mode: TilingMode = TilingMode.ShiftBoundary,
    read_source_func: Callable | None = None,
    extension_filter: str = "",
    dataloader_params: dict | None = None,
    **dataset_kwargs,
) -> MicroSplitPredictDataModule:
    """
    Create a MicroSplitPredictDataModule for MicroSplit-style prediction datasets.

    Parameters
    ----------
    pred_data : str or Path or numpy.ndarray
        Prediction data, can be a path to a folder, a file or a numpy array.
    tile_size : tuple
        Size of one tile of data.
    data_type : DataType
        Type of the dataset (must be a DataType enum value).
    axes : str
        Axes of the data (e.g., 'SYX').
    batch_size : int, default=1
        Batch size for the prediction dataloader.
    num_channels : int, default=2
        Number of channels in the input.
    depth3D : int, default=1
        Number of slices in 3D.
    grid_size : int, optional
        Grid size for patch extraction.
    multiscale_count : int, optional
        Number of LC scales.
    data_stats : tuple, optional
        Data statistics, by default None.
    tiling_mode : TilingMode, default=ShiftBoundary
        Tiling mode for patch extraction.
    read_source_func : Callable, optional
        Function to read the source data.
    extension_filter : str, optional
        File extension filter.
    dataloader_params : dict, optional
        Parameters for the prediction dataloader.
    **dataset_kwargs :
        Additional arguments passed to MicroSplitDataConfig.

    Returns
    -------
    MicroSplitPredictDataModule
        Configured MicroSplitPredictDataModule instance.
    """
    if dataloader_params is None:
        dataloader_params = {}

    # Create prediction config with only valid parameters
    prediction_config_params = {
        "data_type": data_type,
        "image_size": tile_size,
        "num_channels": num_channels,
        "depth3D": depth3D,
        "grid_size": grid_size,
        "multiscale_lowres_count": multiscale_count,
        "data_stats": data_stats,
        "tiling_mode": tiling_mode,
        "batch_size": batch_size,
        "datasplit_type": DataSplitType.Test,  # for prediction, use all data
        **dataset_kwargs,
    }

    pred_config = MicroSplitDataConfig(**prediction_config_params)

    # Remove batch_size from dataloader_params if present
    if "batch_size" in dataloader_params:
        del dataloader_params["batch_size"]

    return MicroSplitPredictDataModule(
        pred_config=pred_config,
        pred_data=pred_data,
        read_source_func=(
            read_source_func if read_source_func is not None else get_train_val_data
        ),
        extension_filter=extension_filter,
        dataloader_params=dataloader_params,
    )