careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,1121 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A place for Datasets and Dataloaders.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Optional, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
from .utils.empty_patch_fetcher import EmptyPatchFetcher
|
|
13
|
+
from .utils.index_manager import GridIndexManager
|
|
14
|
+
from .utils.index_switcher import IndexSwitcher
|
|
15
|
+
from .config import MicroSplitDataConfig
|
|
16
|
+
from .types import DataSplitType, TilingMode
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MultiChDloader(Dataset):
|
|
20
|
+
"""Multi-channel dataset loader."""
|
|
21
|
+
|
|
22
|
+
def __init__(
    self,
    data_config: MicroSplitDataConfig,
    datapath: Union[str, Path],
    load_data_fn: Optional[Callable] = None,
    val_fraction: float = 0.1,
    test_fraction: float = 0.1,
    allow_generation: bool = False,
):
    """Load, preprocess and index a multi-channel dataset.

    Parameters
    ----------
    data_config : MicroSplitDataConfig
        Dataset configuration (split type, tiling mode, patch/grid sizes,
        normalization, augmentation and synthetic-noise settings).
    datapath : str or Path
        Location of the data; forwarded to `load_data_fn`.
    load_data_fn : Callable, optional
        Called as ``load_data_fn(data_config, datapath, datasplit_type,
        val_fraction=..., test_fraction=..., allow_generation=...)`` and must
        return the data array. Despite the ``None`` default it is required.
    val_fraction : float
        Forwarded to `load_data_fn`.
    test_fraction : float
        Forwarded to `load_data_fn`.
    allow_generation : bool
        Forwarded to `load_data_fn`. Generation is allowed when either this
        argument or ``data_config.allow_generation`` is truthy.
    """
    import warnings

    self._data_type = data_config.data_type
    self._fpath = datapath
    self._data = self._noise_data = None
    self.Z = 1
    self._5Ddata = False
    self._tiling_mode = data_config.tiling_mode
    # By default, if the noise is present, add it to the input and target.
    self._disable_noise = False  # to add synthetic noise
    self._poisson_noise_factor = None
    self._train_index_switcher = None
    self._depth3D = data_config.depth3D
    self._mode_3D = data_config.mode_3D
    # NOTE: Input is the sum of the different channels, not their average.
    self._input_is_sum = data_config.input_is_sum
    self._num_channels = data_config.num_channels
    self._input_idx = data_config.input_idx
    self._tar_idx_list = data_config.target_idx_list

    # Only the training split reads a random valid-target fraction; defaulting
    # it to None keeps the attribute defined for every split.
    self._validtarget_rand_fract = None
    if data_config.datasplit_type == DataSplitType.Train:
        self._datausage_fraction = data_config.trainig_datausage_fraction
        self._validtarget_rand_fract = data_config.validtarget_random_fraction
    elif data_config.datasplit_type == DataSplitType.Val:
        self._datausage_fraction = data_config.validation_datausage_fraction
    else:
        self._datausage_fraction = 1.0

    # Preprocessing inside load_data reads the attributes set above, so they
    # must be initialised before this call.
    self.load_data(
        data_config,
        data_config.datasplit_type,
        load_data_fn=load_data_fn,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        # The constructor argument used to be ignored; honour it as well.
        allow_generation=allow_generation or data_config.allow_generation,
    )
    self._normalized_input = data_config.normalized_input
    self._quantile = 1.0
    self._channelwise_quantile = False
    self._background_quantile = 0.0
    self._clip_background_noise_to_zero = False
    self._skip_normalization_using_mean = False
    self._empty_patch_replacement_enabled = False

    self._background_values = None

    self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
    if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
        if (
            self._overlapping_padding_kwargs is None
            or data_config.multiscale_lowres_count is not None
        ):
            warnings.warn(
                "Padding is not used with this alignment style", stacklevel=2
            )
    else:
        assert (
            self._overlapping_padding_kwargs is not None
        ), "When not trimming boundary, padding is needed."

    self._is_train = data_config.datasplit_type == DataSplitType.Train

    # input = alpha * ch1 + (1 - alpha) * ch2.
    # alpha is sampled randomly between these two extremes.
    self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None

    self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None

    # `"grid_size" in data_config` returns False for the config object, so use
    # attribute access and fall back to the patch size.
    try:
        grid_size = data_config.grid_size
    except AttributeError:
        grid_size = data_config.image_size

    if self._is_train:
        self._start_alpha_arr = data_config.start_alpha
        self._end_alpha_arr = data_config.end_alpha

    # Every split builds the index manager the same way.
    self.set_img_sz(data_config.image_size, grid_size)

    if self._is_train and self._validtarget_rand_fract is not None:
        self._train_index_switcher = IndexSwitcher(
            self.idx_manager, data_config, self._img_sz
        )

    self._return_alpha = False
    self._return_index = False

    self._empty_patch_replacement_enabled = (
        data_config.empty_patch_replacement_enabled and self._is_train
    )
    if self._empty_patch_replacement_enabled:
        self._empty_patch_replacement_channel_idx = (
            data_config.empty_patch_replacement_channel_idx
        )
        self._empty_patch_replacement_probab = (
            data_config.empty_patch_replacement_probab
        )
        data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
        # NOTE: This is on the raw data, so it must run before the background
        # is removed below.
        self._empty_patch_fetcher = EmptyPatchFetcher(
            self.idx_manager,
            self._img_sz,
            data_frames,
            max_val_threshold=data_config.empty_patch_max_val_threshold,
        )

    self.rm_bkground_set_max_val_and_upperclip_data(
        data_config.max_val, data_config.datasplit_type
    )

    self._mean = None
    self._std = None
    self._use_one_mu_std = data_config.use_one_mu_std

    self._target_separate_normalization = data_config.target_separate_normalization

    self._enable_rotation = data_config.enable_rotation_aug
    # Z flips are only active when rotation augmentation is enabled.
    self._flipz_3D = data_config.random_flip_z_3D and self._enable_rotation

    self._enable_random_cropping = data_config.enable_random_cropping
    self._uncorrelated_channels = (
        data_config.uncorrelated_channels and self._is_train
    )
    self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
    assert self._is_train or self._uncorrelated_channels is False
    assert (
        self._enable_random_cropping is True or self._uncorrelated_channels is False
    )

    # Random flips and random 90-degree rotations.
    self._rotation_transform = None
    if self._enable_rotation:
        # TODO: fix this import
        import albumentations as A

        self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])
|
|
181
|
+
|
|
182
|
+
def disable_noise(self):
    """Stop adding the stored synthetic noise to inputs and targets.

    Raises
    ------
    AssertionError
        If Poisson noise was applied, because it is baked into the data.
    """
    poisson_applied = self._poisson_noise_factor is not None
    assert (
        not poisson_applied
    ), "This is not supported. Poisson noise is added to the data itself and so the noise cannot be disabled."
    self._disable_noise = True
|
|
187
|
+
|
|
188
|
+
def enable_noise(self):
    """Re-enable adding the stored synthetic noise to inputs and targets."""
    self._disable_noise = False
|
|
190
|
+
|
|
191
|
+
def get_data_shape(self):
    """Return the shape of the currently held data array."""
    data = self._data
    return data.shape
|
|
193
|
+
|
|
194
|
+
def load_data(
    self,
    data_config,
    datasplit_type,
    load_data_fn: Callable,
    val_fraction=None,
    test_fraction=None,
    allow_generation=None,
):
    """Load the data of one split into ``self._data`` and preprocess it.

    Parameters
    ----------
    data_config
        Dataset configuration; forwarded to `load_data_fn` and to
        `_loaded_data_preprocessing`.
    datasplit_type
        Split to load; forwarded to `load_data_fn`.
    load_data_fn : Callable
        Called as ``load_data_fn(data_config, self._fpath, datasplit_type,
        val_fraction=..., test_fraction=..., allow_generation=...)``; its
        return value becomes ``self._data``.
    val_fraction, test_fraction, allow_generation
        Forwarded unchanged to `load_data_fn`.

    Raises
    ------
    TypeError
        If `load_data_fn` is not callable (e.g. the constructor's ``None``
        default was left in place).
    """
    # Fail with an actionable message instead of
    # "'NoneType' object is not callable".
    if not callable(load_data_fn):
        raise TypeError(
            f"[{self.__class__.__name__}] `load_data_fn` must be a callable "
            f"returning the data array, got {load_data_fn!r}."
        )
    self._data = load_data_fn(
        data_config,
        self._fpath,
        datasplit_type,
        val_fraction=val_fraction,
        test_fraction=test_fraction,
        allow_generation=allow_generation,
    )
    self._loaded_data_preprocessing(data_config)
|
|
212
|
+
|
|
213
|
+
def _loaded_data_preprocessing(self, data_config):
    """Post-process freshly loaded data in place.

    Steps, in order:

    1. If ``self._datausage_fraction < 1``, keep only the leading frames
       (cropping the single remaining frame if only one is left).
    2. If ``data_config.poisson_noise_factor > 0``, replace the data by a
       Poisson sample of ``data / poisson_noise_factor``.
    3. If Gaussian noise is enabled, draw ``self._noise_data`` with one extra
       last-axis entry (index 0 for the input, 1: for the targets).
    4. For 5D data: keep it 5D in 3D mode; otherwise merge the first two axes
       into the frame axis, for the data and the synthetic noise alike.
    5. Validate the 3D depth and the channel count against the config.
    """
    old_shape = self._data.shape
    if self._datausage_fraction < 1.0:
        # NOTE(review): axes 1:3 are (H, W) only for N x H x W x C data; for
        # 5D input they are (Z, H) — confirm this path is 4D-only.
        framepixelcount = np.prod(self._data.shape[1:3])
        pixelcount = int(
            len(self._data) * framepixelcount * self._datausage_fraction
        )
        frame_count = int(np.ceil(pixelcount / framepixelcount))
        last_frame_reduced_size, _ = IndexSwitcher.get_reduced_frame_size(
            self._data.shape[:3], self._datausage_fraction
        )
        self._data = self._data[:frame_count].copy()
        if frame_count == 1:
            self._data = self._data[
                :, :last_frame_reduced_size, :last_frame_reduced_size
            ].copy()
        print(
            f"[{self.__class__.__name__}] New data shape: {self._data.shape} Old: {old_shape}"
        )

    msg = ""
    if data_config.poisson_noise_factor > 0:
        self._poisson_noise_factor = data_config.poisson_noise_factor
        msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
        self._data = np.random.poisson(self._data / self._poisson_noise_factor)

    if data_config.enable_gaussian_noise:
        synthetic_scale = data_config.synthetic_gaussian_scale
        msg += f"Adding Gaussian noise with scale {synthetic_scale}"
        # 0 => noise for input. 1: => noise for all targets.
        shape = self._data.shape
        self._noise_data = np.random.normal(
            0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
        )
        if data_config.input_has_dependant_noise:
            msg += ". Moreover, input has dependent noise"
            self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
    # Avoid printing an empty line when no synthetic noise was added.
    if msg:
        print(msg)

    if len(self._data.shape) == 5:
        if self._mode_3D:
            self._5Ddata = True
        else:
            assert self._depth3D == 1, "Depth3D must be 1 for 2D training"
            self._data = self._data.reshape(-1, *self._data.shape[2:])
            # The noise was drawn on the 5D layout; flatten it the same way so
            # it stays aligned frame-by-frame with the data.
            if self._noise_data is not None:
                self._noise_data = self._noise_data.reshape(
                    -1, *self._noise_data.shape[2:]
                )

    if self._5Ddata:
        self.Z = self._data.shape[1]

    if self._depth3D > 1:
        assert self._5Ddata, "Data must be 5D:NxZxHxWxC for 3D data"

    assert (
        self._data.shape[-1] == self._num_channels
    ), "Number of channels in data and config do not match."
|
|
268
|
+
|
|
269
|
+
def save_background(self, channel_idx, frame_idx, background_value):
    """Record the background offset subtracted from one frame and channel."""
    cell = (frame_idx, channel_idx)
    self._background_values[cell] = background_value
|
|
271
|
+
|
|
272
|
+
def get_background(self, channel_idx, frame_idx):
    """Return the background offset recorded for one frame and channel."""
    table = self._background_values
    return table[frame_idx, channel_idx]
|
|
274
|
+
|
|
275
|
+
def remove_background(self):
    """Subtract a per-frame, per-channel background quantile from the data.

    ``self._background_values`` is reset to zeros of shape
    ``(n_frames, n_channels)``. If ``self._background_quantile`` is 0, nothing
    else happens. Otherwise, for every frame and channel the quantile is
    computed, truncated to an int, recorded via `save_background` and
    subtracted in place. Negative values are then clipped to 0 when
    ``self._clip_background_noise_to_zero`` is set.
    """
    self._background_values = np.zeros((self._data.shape[0], self._data.shape[-1]))

    if self._background_quantile == 0.0:
        assert (
            self._clip_background_noise_to_zero is False
        ), "This operation currently happens later in this function."
        return

    if np.issubdtype(self._data.dtype, np.unsignedinteger):
        # Unsigned integers wrap around on subtraction instead of going
        # negative. Promote to a signed type that holds every value
        # (uint8 -> int16, uint16 -> int32, uint32 -> int64).
        self._data = self._data.astype(np.promote_types(self._data.dtype, np.int8))

    for ch in range(self._data.shape[-1]):
        for idx in range(self._data.shape[0]):
            qval = np.quantile(self._data[idx, ..., ch], self._background_quantile)
            assert (
                np.abs(qval) > 20
            ), "We are truncating the qval to an integer which will only make sense if it is large enough"
            # NOTE: Here, there can be an issue if you work with normalized data
            qval = int(qval)
            self.save_background(ch, idx, qval)
            self._data[idx, ..., ch] -= qval

    if self._clip_background_noise_to_zero:
        self._data[self._data < 0] = 0
|
|
302
|
+
|
|
303
|
+
def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
    """Run the intensity preprocessing steps in their required order.

    The background is removed first, so that a max value computed from the
    data (when `max_val` is None) sees background-corrected intensities.
    The resulting ceiling is then applied as an upper clip.
    """
    # 1) background subtraction
    self.remove_background()
    # 2) choose (or compute) the ceiling
    self.set_max_val(max_val, datasplit_type)
    # 3) clip everything above it
    self.upperclip_data()
|
|
307
|
+
|
|
308
|
+
def upperclip_data(self):
    """Clip ``self._data`` in place so no value exceeds ``self.max_val``.

    ``self.max_val`` may be a scalar (one ceiling for all values) or a
    sequence/array with one ceiling per channel (last data axis).

    Raises
    ------
    AssertionError
        If a per-channel ceiling does not match the number of channels.
    """
    max_val = self.max_val
    if np.ndim(max_val) > 0:
        # Per-channel ceilings. Tuples and arrays used to fall through to the
        # scalar branch, where the boolean-mask assignment failed.
        chN = self._data.shape[-1]
        assert chN == len(max_val)
        for ch in range(chN):
            ch_data = self._data[..., ch]
            ch_q = max_val[ch]
            ch_data[ch_data > ch_q] = ch_q
            self._data[..., ch] = ch_data
    else:
        self._data[self._data > max_val] = max_val
|
|
319
|
+
|
|
320
|
+
def compute_max_val(self):
    """Return the ``self._quantile`` quantile of the data.

    Returns a list with one value per channel (last axis) when
    ``self._channelwise_quantile`` is set, otherwise a single global value.
    """
    if not self._channelwise_quantile:
        return np.quantile(self._data, self._quantile)

    n_channels = self._data.shape[-1]
    per_channel = []
    for c in range(n_channels):
        per_channel.append(np.quantile(self._data[..., c], self._quantile))
    return per_channel
|
|
329
|
+
|
|
330
|
+
def set_max_val(self, max_val, datasplit_type):
    """Set ``self.max_val``, computing it from the data when not given.

    A missing (None) `max_val` is only accepted for the training split, in
    which case it is computed with `compute_max_val`.
    """
    if max_val is not None:
        self.max_val = max_val
        return

    # Only the training split may derive the ceiling from its own data.
    assert datasplit_type == DataSplitType.Train
    self.max_val = self.compute_max_val()
|
|
338
|
+
|
|
339
|
+
def get_max_val(self):
    """Return ``self.max_val``, the upper clipping value set by `set_max_val`."""
    ceiling = self.max_val
    return ceiling
|
|
341
|
+
|
|
342
|
+
def get_img_sz(self):
    """Return the in-plane patch size stored by `set_img_sz`."""
    patch_size = self._img_sz
    return patch_size
|
|
344
|
+
|
|
345
|
+
def get_num_frames(self):
    """Return the size of the leading (frame) axis of the data."""
    shape = self._data.shape
    return shape[0]
|
|
347
|
+
|
|
348
|
+
def reduce_data(
    self,
    t_list=None,
    z_start=None,
    z_end=None,
    h_start=None,
    h_end=None,
    w_start=None,
    w_end=None,
):
    """Keep only a subset of frames and a spatial crop of the data.

    Parameters
    ----------
    t_list : list of int, optional
        Frame indices to keep (default: all frames, in order).
    z_start, z_end : int, optional
        Z range to keep; only used for 5D data.
    h_start, h_end, w_start, w_end : int, optional
        Spatial crop bounds; None means the full extent.

    The synthetic noise and the per-frame background values are reduced
    consistently with the data, and the index manager is rebuilt through
    `set_img_sz` for the new shape.
    """
    if t_list is None:
        t_list = list(range(self._data.shape[0]))

    # slice(None, None) keeps the full extent, matching the previous explicit
    # 0 / shape defaults.
    index = [t_list]
    if self._5Ddata:
        index.append(slice(z_start, z_end))
    index += [slice(h_start, h_end), slice(w_start, w_end), slice(None)]
    index = tuple(index)

    self._data = self._data[index].copy()
    if self._noise_data is not None:
        self._noise_data = self._noise_data[index].copy()
    # Background offsets are stored per frame; they must follow the selected
    # frames, otherwise get_background would return another frame's offset.
    if self._background_values is not None:
        self._background_values = self._background_values[t_list].copy()

    self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
    print(
        f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
    )
|
|
402
|
+
|
|
403
|
+
def get_idx_manager_shapes(
    self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
):
    """Build the patch and grid shapes handed to the grid index manager.

    Shapes follow the data layout: ``(1, Z, H, W, C)`` for 5D data and
    ``(1, H, W, C)`` otherwise, where ``C`` is the size of the last axis of
    ``self._data``, ``Z`` is ``self._depth3D`` for patches, and the leading 1
    takes one frame per patch.

    Args:
        patch_size: spatial (H and W) size of a patch.
        grid_size: grid step. A scalar (Python or NumPy integer) applies to H
            and W (and 1 to Z for 5D data). For 5D data a 3-sequence
            ``(z, h, w)`` is also accepted; each entry must not exceed the
            matching patch dimension.

    Returns:
        ``(patch_shape, grid_shape)`` tuples.

    Raises:
        AssertionError: if ``grid_size`` has an unsupported form or is larger
            than the patch.
    """
    num_channels = self._data.shape[-1]
    # NumPy integer scalars (e.g. np.int64 read from configs/arrays) are valid
    # scalar grid sizes too.
    is_scalar_grid = isinstance(grid_size, (int, np.integer))

    if self._5Ddata:
        patch_shape = (1, self._depth3D, patch_size, patch_size, num_channels)
        if is_scalar_grid:
            grid_shape = (1, 1, grid_size, grid_size, num_channels)
        else:
            assert (
                len(grid_size) == 3
            ), f"Expected a (z, h, w) grid size, got {grid_size}"
            assert all(
                g <= p for g, p in zip(grid_size, patch_shape[1:-1])
            ), (
                f"Grid size {grid_size} must be less than or equal to "
                f"patch size {patch_shape[1:-1]}"
            )
            grid_shape = (1, grid_size[0], grid_size[1], grid_size[2], num_channels)
    else:
        assert is_scalar_grid, f"Expected an integer grid size, got {grid_size}"
        grid_shape = (1, grid_size, grid_size, num_channels)
        patch_shape = (1, patch_size, patch_size, num_channels)

    return patch_shape, grid_shape
|
|
423
|
+
|
|
424
|
+
def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
    """Change the patch size and grid size on the fly and rebuild the index manager.

    Args:
        image_size: patch size given as a sequence; only its last entry is
            kept (callers such as ``reduce_data`` pass ``[sz, sz]``).
        grid_size: the frame is divided into grids of this size; a patch of
            size ``image_size`` centred on a grid cell is returned.
    """
    # Only the last entry is used; a workaround for the image-shape format
    # coming from the new configuration. TODO revisit!
    self._img_sz = image_size[-1]
    self._grid_sz = grid_size

    data_shape = self._data.shape
    patch_shape, grid_shape = self.get_idx_manager_shapes(self._img_sz, self._grid_sz)
    self.idx_manager = GridIndexManager(
        data_shape, grid_shape, patch_shape, self._tiling_mode
    )
    # self.set_repeat_factor()
|
|
443
|
+
|
|
444
|
+
def __len__(self):
    """Number of samples: the total grid (patch) count of the index manager."""
    manager = self.idx_manager
    return manager.total_grid_count()
|
|
448
|
+
|
|
449
|
+
def set_repeat_factor(self):
    """Store in ``self._repeat_factor`` the number of grid cells (rows * cols).

    The step used for counting is ``self._grid_sz`` when it is larger than 1,
    otherwise ``self._img_sz``.
    """
    step = self._grid_sz if self._grid_sz > 1 else self._img_sz
    n_rows = self.idx_manager.grid_rows(step)
    n_cols = self.idx_manager.grid_cols(step)
    self._repeat_factor = n_rows * n_cols
|
|
458
|
+
|
|
459
|
+
def _init_msg(
    self,
):
    """Return a one-line, human-readable summary of the dataset configuration."""
    dim_sizes = ",".join(
        str(self.idx_manager.get_individual_dim_grid_count(dim))
        for dim in range(len(self._data.shape))
    )

    msg = f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
    msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
    # A separator/label was missing here, which glued the total patch count
    # onto the NumPatchPerN value.
    msg += f" TotalPatches:{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
    msg += f" TrimB:{self._tiling_mode}"
    msg += f" Rot:{self._enable_rotation}"
    if self._flipz_3D:
        msg += f" FlipZ:{self._flipz_3D}"

    msg += f" RandCrop:{self._enable_random_cropping}"
    msg += f" Channel:{self._num_channels}"
    if self._input_is_sum:
        msg += f" SummedInput:{self._input_is_sum}"

    if self._empty_patch_replacement_enabled:
        # Channel index and probability describe the replacement, so they are
        # attached to its own entry rather than to whatever field came next.
        msg += (
            f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}"
            f"-{self._empty_patch_replacement_channel_idx}"
            f"-{self._empty_patch_replacement_probab}"
        )
    if self._uncorrelated_channels:
        msg += f" Uncorr:{self._uncorrelated_channels}"
    if self._background_quantile > 0.0:
        msg += f" BckQ:{self._background_quantile}"

    if self._start_alpha_arr is not None:
        msg += f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]"
    return msg
|
|
497
|
+
|
|
498
|
+
def _crop_imgs(self, index, *img_tuples: np.ndarray):
    """Crop every array in ``img_tuples`` at one shared patch location.

    With ``self._img_sz`` unset the arrays are returned untouched. Otherwise
    the start location is random (when random cropping is enabled; for 5D data
    a random depth start is prepended) or derived from ``index``.

    Returns:
        The cropped arrays followed by a metadata dict (flip flags, plus the
        full ``h``/``w`` extents in the uncropped case).
    """
    h, w = img_tuples[0].shape[-2:]
    if self._img_sz is None:
        passthrough_info = {"h": [0, h], "w": [0, w], "hflip": False, "wflip": False}
        return (*img_tuples, passthrough_info)

    if not self._enable_random_cropping:
        start_loc = self._get_deterministic_loc(index)
    else:
        start_loc = self._get_random_hw(h, w)
        if self._5Ddata:
            depth = img_tuples[0].shape[-3]
            start_loc = (np.random.choice(1 + depth - self._depth3D),) + start_loc

    cropped = tuple(
        self._crop_flip_img(img, start_loc, False, False) for img in img_tuples
    )
    return (*cropped, {"hflip": False, "wflip": False})
|
|
527
|
+
|
|
528
|
+
def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
    """Crop one patch of ``self.idx_manager.patch_shape`` starting at ``patch_start_loc``.

    TrimBoundary / ShiftBoundary tiling (the training path) slices the patch
    directly. Any other tiling mode (the evaluation path) may start at a
    negative coordinate or run past the frame, so it is delegated to
    ``_crop_img_with_padding``, which pads the overflow.
    """
    if self._tiling_mode not in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
        # Evaluation: out-of-frame coordinates are possible and need padding
        # (not needed in the LeftTop alignment).
        return self._crop_img_with_padding(img, patch_start_loc)

    # Training. Using _crop_img_with_padding here would also work; the direct
    # slice just keeps the training path easy to follow.
    end_loc = (
        np.array(patch_start_loc, dtype=np.int32)
        + self.idx_manager.patch_shape[1:-1]
    )
    if self._5Ddata:
        (z0, y0, x0), (z1, y1, x1) = patch_start_loc, end_loc
        return img[..., z0:z1, y0:y1, x0:x1]

    (y0, x0), (y1, x1) = patch_start_loc, end_loc
    return img[..., y0:y1, x0:x1]
|
|
551
|
+
|
|
552
|
+
def get_begin_end_padding(self, start_pos, end_pos, max_len):
    """Return ``(pad_start, pad_end)`` needed for a crop ``[start_pos, end_pos)``.

    ``pad_start`` is how far ``start_pos`` lies before 0 and ``pad_end`` how
    far ``end_pos`` lies beyond ``max_len``; each is 0 when that side stays
    inside ``[0, max_len]``. Padding this way keeps a grid cell of
    ``self._grid_sz`` centred in a patch of ``self._img_sz``.
    """
    pad_start = -start_pos if start_pos < 0 else 0
    overflow = end_pos - max_len
    pad_end = overflow if overflow > 0 else 0
    return pad_start, pad_end
|
|
565
|
+
|
|
566
|
+
def _crop_img_with_padding(
    self, img: np.ndarray, patch_start_loc, max_len_vals=None
):
    """Crop a patch that may extend beyond the frame, padding the overflow.

    The in-frame part of ``[start, start + patch_shape)`` is sliced from the
    trailing spatial axes of ``img`` (3 for 5D data, 2 otherwise). Any part
    outside ``[0, max_len)`` is then added back with ``np.pad`` using
    ``self._overlapping_padding_kwargs``. The first axis of ``img`` (e.g. the
    singleton axis of the per-channel arrays built by ``_load_img``) is never
    padded.

    Args:
        img: array to crop.
        patch_start_loc: per-spatial-axis start coordinates; may be negative.
        max_len_vals: per-spatial-axis frame extents; defaults to
            ``self.idx_manager.data_shape[1:-1]``.

    Returns:
        The cropped (and, if needed, padded) array.
    """
    if max_len_vals is None:
        max_len_vals = self.idx_manager.data_shape[1:-1]
    patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
        self.idx_manager.patch_shape[1:-1], dtype=int
    )

    valid_slices = []
    padding = [(0, 0)]
    for start_idx, end_idx, max_len in zip(
        patch_start_loc, patch_end_loc, max_len_vals
    ):
        # max()/min() clip the slice to the frame since start may be negative.
        valid_slices.append(slice(max(0, start_idx), min(max_len, end_idx)))
        if start_idx < 0 or end_idx > max_len:
            padding.append(self.get_begin_end_padding(start_idx, end_idx, max_len))
        else:
            padding.append((0, 0))

    if self._5Ddata:
        new_img = img[..., valid_slices[0], valid_slices[1], valid_slices[2]]
    else:
        new_img = img[..., valid_slices[0], valid_slices[1]]

    # `padding` is a nested Python list: `padding == 0` is always False, so the
    # check must be done on an array, otherwise np.pad runs on every crop.
    if np.any(np.asarray(padding) != 0):
        new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)

    return new_img
|
|
607
|
+
|
|
608
|
+
def _crop_flip_img(
    self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
):
    """Crop a patch via ``_crop_img``, optionally flip it, and cast to float32.

    ``h_flip`` reverses the second-to-last axis and ``w_flip`` the last axis.
    """
    cropped = self._crop_img(img, patch_start_loc)
    row_step = -1 if h_flip else 1
    col_step = -1 if w_flip else 1
    if h_flip or w_flip:
        cropped = cropped[..., ::row_step, ::col_step]
    return cropped.astype(np.float32)
|
|
618
|
+
|
|
619
|
+
def _load_img(
    self, index: Union[int, tuple[int, int]]
) -> tuple[np.ndarray, np.ndarray]:
    """Load the whole frame for ``index``, split into per-channel arrays.

    ``index`` is either an integer dataset index (Python or any NumPy integer
    type) or a tuple whose first element is the dataset index.

    Returns:
        ``(images, noise)``: tuples holding one array per channel (last axis)
        of ``self._data`` / ``self._noise_data``, each with a new leading axis
        of size 1. ``noise`` is empty when there is no noise data or noise is
        disabled.
    """
    # np.integer covers np.int32 etc., which previously fell through to
    # ``index[0]`` and failed on a scalar.
    idx = index if isinstance(index, (int, np.integer)) else index[0]

    patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
    frame_idx = patch_loc_list[0]
    frame = self._data[frame_idx]
    loaded_imgs = tuple(frame[None, ..., i] for i in range(frame.shape[-1]))

    noise = ()
    if self._noise_data is not None and not self._disable_noise:
        # Index the noise frame once instead of once per channel.
        noise_frame = self._noise_data[frame_idx]
        noise = tuple(
            noise_frame[None, ..., i] for i in range(noise_frame.shape[-1])
        )
    return loaded_imgs, noise
|
|
648
|
+
|
|
649
|
+
def get_mean_std(self):
    """Return the stored ``(mean, std)`` normalization statistics."""
    stats = (self._mean, self._std)
    return stats
|
|
651
|
+
|
|
652
|
+
def set_mean_std(self, mean_val, std_val):
    """Store the normalization statistics later returned by ``get_mean_std``."""
    self._mean, self._std = mean_val, std_val
|
|
655
|
+
|
|
656
|
+
def normalize_img(self, *img_tuples):
    """Normalize each channel image with its own target mean and std.

    The i-th array in ``img_tuples`` is standardized with the i-th entry of
    the squeezed ``mean["target"]`` / ``std["target"]`` statistics.

    Returns:
        Tuple of normalized arrays, in input order.
    """
    mean_dict, std_dict = self.get_mean_std()
    # atleast_1d: squeezing single-channel statistics gives a 0-d array that
    # cannot be indexed with [i].
    mean = np.atleast_1d(mean_dict["target"].squeeze())
    std = np.atleast_1d(std_dict["target"].squeeze())
    return tuple((img - mean[i]) / std[i] for i, img in enumerate(img_tuples))
|
|
667
|
+
|
|
668
|
+
def normalize_input(self, x):
    """Standardize ``x`` with the averages of the stored input mean and std arrays."""
    mean_dict, std_dict = self.get_mean_std()
    input_mean = mean_dict["input"].mean()
    input_std = std_dict["input"].mean()
    centred = x - input_mean
    return centred / input_std
|
|
673
|
+
|
|
674
|
+
def normalize_target(self, target):
    """Standardize ``target`` with the target statistics, dropping their leading axis."""
    mean_dict, std_dict = self.get_mean_std()
    target_mean, target_std = (
        stats["target"].squeeze(0) for stats in (mean_dict, std_dict)
    )
    return (target - target_mean) / target_std
|
|
679
|
+
|
|
680
|
+
def get_grid_size(self):
    """Return the current grid size (``self._grid_sz``)."""
    grid_size = self._grid_sz
    return grid_size
|
|
682
|
+
|
|
683
|
+
def get_idx_manager(self):
    """Return the grid index manager built by ``set_img_sz``."""
    manager = self.idx_manager
    return manager
|
|
685
|
+
|
|
686
|
+
def per_side_overlap_pixelcount(self):
    """Pixels by which a patch extends beyond its grid cell on each side."""
    total_overlap = self._img_sz - self._grid_sz
    return total_overlap // 2
|
|
688
|
+
|
|
689
|
+
# def on_boundary(self, cur_loc, frame_size):
|
|
690
|
+
# return cur_loc + self._img_sz > frame_size or cur_loc < 0
|
|
691
|
+
|
|
692
|
+
def _get_deterministic_loc(self, index: int):
    """Return the top-left corner of the patch for dataset ``index``.

    The index manager's location starts with the frame index (used by
    ``_load_img`` to index ``self._data``) and ends with the channel entry.
    Both are dropped, leaving only the spatial start coordinates.
    """
    location = self.idx_manager.get_patch_location_from_dataset_idx(index)
    spatial_start = location[1:-1]
    return spatial_start
|
|
699
|
+
|
|
700
|
+
def compute_individual_mean_std(self):
    """Compute the mean and std of every channel (last axis) of ``self._data``.

    Channels are processed one at a time because numpy 1.19.2 has issues
    computing these statistics over large arrays in one call
    (https://github.com/numpy/numpy/issues/8869). The mean is forced to 0.0
    when ``self._skip_normalization_using_mean`` is set. When noise data is
    present, the std of channel ``c`` is taken over data channel ``c`` plus
    noise channel ``c + 1``.

    Returns:
        ``(mean, std)`` of shape ``(1, C, 1, 1)``, or ``(1, C, 1, 1, 1)`` when
        ``self._5Ddata`` is set.
    """

    def _channel_stats(ch):
        channel = self._data[..., ch]
        ch_mean = 0.0 if self._skip_normalization_using_mean else channel.mean()
        if self._noise_data is None:
            ch_std = channel.std()
        else:
            ch_std = (channel + self._noise_data[..., ch + 1]).std()
        return ch_mean, ch_std

    per_channel = [_channel_stats(ch) for ch in range(self._data.shape[-1])]
    mean = np.array([ch_mean for ch_mean, _ in per_channel])
    std = np.array([ch_std for _, ch_std in per_channel])

    # NOTE: ideally the extra depth axis would depend on whether the model
    # expects 3D data rather than on the data layout.
    trailing = (None, None, None) if self._5Ddata else (None, None)
    expand = (None, slice(None), *trailing)
    return mean[expand], std[expand]
|
|
730
|
+
|
|
731
|
+
def compute_mean_std(self, allow_for_validation_data=False):
    """Compute normalization statistics for the input and, optionally, the target.

    This must only be computed on training data; ``allow_for_validation_data``
    bypasses that check.

    Cases:
      * ``self._input_idx`` set: input and target statistics are sliced from
        the per-channel statistics of ``compute_individual_mean_std``.
      * ``self._input_is_sum``: the input mean is the sum of the channel means
        and the input std is the Euclidean norm of the channel stds (exact for
        uncorrelated channels).
      * otherwise: one mean/std over the whole of ``self._data``; with noise
        data the std is taken over data + noise channels ``1:``.
    In the last two cases the statistic is repeated once per channel, the
    mean is zeroed when ``self._skip_normalization_using_mean`` is set, and
    per-channel target statistics are added when
    ``self._target_separate_normalization`` is set.

    Returns:
        ``(mean_dict, std_dict)`` keyed by ``"input"`` (and ``"target"`` when
        computed). In the last two cases the input statistics have shape
        ``(1, num_channels, 1, 1)``, or ``(1, num_channels, 1, 1, 1)`` for
        5D data.
    """
    assert (
        self._is_train is True or allow_for_validation_data
    ), "This is just allowed for training data"
    assert self._use_one_mu_std is True, "This is the only supported case"

    if self._input_idx is not None:
        assert (
            self._tar_idx_list is not None
        ), "tar_idx_list must be set if input_idx is set."
        assert self._noise_data is None, "This is not supported with noise"
        assert (
            self._target_separate_normalization is True
        ), "This is not supported with target_separate_normalization=False"

        mean, std = self.compute_individual_mean_std()
        input_sel = slice(self._input_idx, self._input_idx + 1)
        mean_dict = {
            "input": mean[:, input_sel],
            "target": mean[:, self._tar_idx_list],
        }
        std_dict = {
            "input": std[:, input_sel],
            "target": std[:, self._tar_idx_list],
        }
        return mean_dict, std_dict

    if self._input_is_sum:
        assert self._noise_data is None, "This is not supported with noise"
        channel_means = [
            np.mean(self._data[..., k : k + 1]) for k in range(self._num_channels)
        ]
        channel_stds = [
            np.std(self._data[..., k : k + 1]) for k in range(self._num_channels)
        ]
        # Always (1, 1, 1, 1), like the branch below. Previously the shape
        # followed the rank of the data (keepdims), so 5D data ended up with
        # 6D statistics after the depth-axis expansion further down.
        mean = np.reshape(np.sum(channel_means), (1, 1, 1, 1))
        std = np.reshape(np.linalg.norm(channel_stds), (1, 1, 1, 1))
    else:
        mean = np.mean(self._data, keepdims=True).reshape(1, 1, 1, 1)
        if self._noise_data is not None:
            # Noise channels 1: are the ones __getitem__ adds to the targets.
            std = np.std(
                self._data + self._noise_data[..., 1:], keepdims=True
            ).reshape(1, 1, 1, 1)
        else:
            std = np.std(self._data, keepdims=True).reshape(1, 1, 1, 1)

    mean = np.repeat(mean, self._num_channels, axis=1)
    std = np.repeat(std, self._num_channels, axis=1)

    if self._skip_normalization_using_mean:
        mean = np.zeros_like(mean)

    if self._5Ddata:
        mean = mean[:, :, None]
        std = std[:, :, None]

    mean_dict = {"input": mean}
    std_dict = {"input": std}

    if self._target_separate_normalization:
        target_mean, target_std = self.compute_individual_mean_std()
        mean_dict["target"] = target_mean
        std_dict["target"] = target_std
    return mean_dict, std_dict
|
|
802
|
+
|
|
803
|
+
def _get_random_hw(self, h: int, w: int):
    """Draw a random top-left corner for a square crop of ``self._img_sz``.

    Each axis is drawn independently and uniformly from every valid start
    ``0 .. dim - self._img_sz`` (inclusive).

    Args:
        h: height of the frame being cropped.
        w: width of the frame being cropped.

    Returns:
        ``(h_start, w_start)``.

    Raises:
        ValueError: if the frame is smaller than the crop along either axis.
    """
    if h < self._img_sz or w < self._img_sz:
        raise ValueError(
            f"Frame of size ({h}, {w}) is smaller than the crop size {self._img_sz}."
        )
    # np.random.choice(n) draws from 0..n-1, so +1 makes the last valid start
    # reachable. This also handles h == img_sz and w == img_sz independently.
    h_start = np.random.choice(h - self._img_sz + 1)
    w_start = np.random.choice(w - self._img_sz + 1)
    return h_start, w_start
|
|
814
|
+
|
|
815
|
+
def _get_img(self, index: Union[int, tuple[int, int]]):
    """Load the frame for ``index`` and crop its image and noise channels.

    Image and noise arrays go through a single ``_crop_imgs`` call so they
    share the same crop location; the trailing metadata dict is discarded.

    Returns:
        ``(cropped_img_tuples, cropped_noise_tuples)``.
    """
    img_tuples, noise_tuples = self._load_img(index)
    *cropped, _crop_info = self._crop_imgs(index, *img_tuples, *noise_tuples)
    n_images = len(img_tuples)
    return tuple(cropped[:n_images]), tuple(cropped[n_images:])
|
|
825
|
+
|
|
826
|
+
def replace_with_empty_patch(self, img_tuples):
    """Replace one channel's content with that of a background ("empty") patch.

    A patch index is drawn from ``self._empty_patch_fetcher``. The channel at
    ``self._empty_patch_replacement_channel_idx`` is taken from that patch;
    every other channel is kept from ``img_tuples``.
    """
    empty_index = self._empty_patch_fetcher.sample()
    empty_imgs, empty_noise = self._get_img(empty_index)
    assert (
        len(empty_noise) == 0
    ), "Noise is not supported with empty patch replacement"
    replaced_ch = self._empty_patch_replacement_channel_idx
    return tuple(
        empty_imgs[ch] if ch == replaced_ch else img
        for ch, img in enumerate(img_tuples)
    )
|
|
842
|
+
|
|
843
|
+
def get_mean_std_for_input(self):
    """Return the ``"input"`` entries of the stored mean and std dictionaries."""
    mean_dict, std_dict = self.get_mean_std()
    input_key = "input"
    return mean_dict[input_key], std_dict[input_key]
|
|
846
|
+
|
|
847
|
+
def _compute_target(self, img_tuples, alpha):
    """Build the training target from the channel images.

    An int ``self._tar_idx_list`` selects that single channel array unchanged.
    Otherwise the channels listed in ``self._tar_idx_list`` (all channels when
    it is ``None``) are concatenated along axis 0. ``alpha`` is not used.
    """
    tar_idx = self._tar_idx_list
    if tar_idx is None:
        return np.concatenate(img_tuples, axis=0)
    if isinstance(tar_idx, int):
        return img_tuples[tar_idx]
    assert isinstance(tar_idx, list) or isinstance(tar_idx, tuple)
    selected = [img_tuples[i] for i in tar_idx]
    return np.concatenate(selected, axis=0)
|
|
859
|
+
|
|
860
|
+
def _compute_input_with_alpha(self, img_tuples, alpha_list):
    """Form the network input and standardize it unless normalization is off.

    The input is ``img_tuples[self._input_idx]`` when that index is set,
    otherwise the sum of the channel images weighted by ``alpha_list``. When
    ``self._normalized_input`` is not ``False`` the input is standardized with
    the input statistics, which must be identical for every channel.

    Returns:
        The input as a float32 array.
    """
    if self._input_idx is None:
        inp = 0
        for weight, channel in zip(alpha_list, img_tuples):
            inp += channel * weight
    else:
        inp = img_tuples[self._input_idx]

    if self._normalized_input is False:
        return inp.astype(np.float32)

    mean, std = self.get_mean_std_for_input()
    mean, std = mean.squeeze(), std.squeeze()
    if mean.size == 1:
        mean, std = mean.reshape(1), std.reshape(1)

    # A single statistic shared by all channels is expected.
    for i in range(len(mean)):
        assert mean[0] == mean[i]
        assert std[0] == std[i]

    standardized = (inp - mean[0]) / std[0]
    return standardized.astype(np.float32)
|
|
889
|
+
|
|
890
|
+
def _sample_alpha(self):
    """Draw one mixing weight per channel, uniform in ``[start_alpha, end_alpha)``."""
    weights = []
    for ch in range(self._num_channels):
        position = np.random.rand()
        low, high = self._start_alpha_arr[ch], self._end_alpha_arr[ch]
        weights.append(low + position * (high - low))
    return weights
|
|
899
|
+
|
|
900
|
+
def _compute_input(self, img_tuples):
    """Mix the channel images into the network input.

    Channels get equal weights (1/n) unless alpha ranges are configured, in
    which case weights are sampled. With ``self._input_is_sum`` the mixed
    input is multiplied by the number of channels.

    Returns:
        ``(input, alpha)`` where ``alpha`` is the list of weights used.
    """
    n_channels = len(img_tuples)
    if self._start_alpha_arr is not None:
        alpha = self._sample_alpha()
    else:
        alpha = [1 / n_channels for _ in range(n_channels)]

    inp = self._compute_input_with_alpha(img_tuples, alpha)
    if self._input_is_sum:
        inp = n_channels * inp
    return inp, alpha
|
|
909
|
+
|
|
910
|
+
def _get_index_from_valid_target_logic(self, index):
    """Optionally replace ``index`` with one drawn from the train index switcher.

    With probability ``self._validtarget_rand_fract`` a valid-target index is
    drawn, otherwise an invalid-target one. When the fraction is ``None`` the
    given ``index`` is returned unchanged.
    """
    if self._validtarget_rand_fract is None:
        return index
    switcher = self._train_index_switcher
    if np.random.rand() < self._validtarget_rand_fract:
        return switcher.get_valid_target_index()
    return switcher.get_invalid_target_index()
|
|
917
|
+
|
|
918
|
+
def _rotate2D(self, img_tuples, noise_tuples):
    """Rotate every 2D plane of the image and noise arrays in one transform call.

    Each plane ``arr[k]`` of the i-th array is registered on
    ``self._rotation_transform`` as an extra ``"image"`` target named
    ``"img{i}_{k}"`` / ``"noise{i}_{k}"``. A single call then transforms all
    planes (presumably with shared random parameters — depends on the
    transform library), and the planes are restacked along axis 0.

    Returns:
        ``(rotated_img_tuples, rotated_noise_tuples)`` as lists.
    """

    def _as_targets(prefix, arrays):
        # One named target per plane.
        return {
            f"{prefix}{i}_{k}": arr[k]
            for i, arr in enumerate(arrays)
            for k in range(len(arr))
        }

    def _restack(prefix, arrays, transformed):
        # Reassemble each array from its transformed planes.
        restacked = []
        for i, arr in enumerate(arrays):
            planes = [transformed[f"{prefix}{i}_{k}"][None] for k in range(len(arr))]
            if len(arr) == 1:
                restacked.append(planes[0])
            else:
                restacked.append(np.concatenate(planes, axis=0))
        return restacked

    img_kwargs = _as_targets("img", img_tuples)
    noise_kwargs = _as_targets("noise", noise_tuples)
    self._rotation_transform.add_targets(
        {key: "image" for key in [*img_kwargs, *noise_kwargs]}
    )
    rot_dic = self._rotation_transform(
        image=img_tuples[0][0], **img_kwargs, **noise_kwargs
    )
    return _restack("img", img_tuples, rot_dic), _restack("noise", noise_tuples, rot_dic)
|
|
959
|
+
|
|
960
|
+
def _rotate(self, img_tuples, noise_tuples):
    """Rotate with ``_rotate3D`` for 5D data and ``_rotate2D`` otherwise."""
    rotate_fn = self._rotate3D if self._5Ddata else self._rotate2D
    return rotate_fn(img_tuples, noise_tuples)
|
|
966
|
+
|
|
967
|
+
def _rotate3D(self, img_tuples, noise_tuples):
    """Rotate volumes plane by plane, with an optional random flip along Z.

    Each array is indexed ``[channel, z, ...]``. Every ``(z, channel)`` plane
    is registered on ``self._rotation_transform`` as an ``"image"`` target and
    transformed in one call; the planes are then restacked into
    ``(channels, self._depth3D, ...)`` arrays. When ``self._flipz_3D`` is set,
    with probability 0.5 the Z order is reversed (once, for images and noise
    alike).

    Returns:
        ``(rotated_img_tuples, rotated_noise_tuples)`` as lists.
    """
    # Decided once so image and noise volumes are flipped consistently.
    flip_z = self._flipz_3D and np.random.rand() < 0.5
    depth = self._depth3D

    def _as_targets(prefix, volumes):
        # Key "{prefix}{i}_{z}_{k}"; z is mirrored when flipping.
        targets = {}
        for i, vol in enumerate(volumes):
            for j in range(depth):
                z_idx = depth - 1 - j if flip_z else j
                for k in range(len(vol)):
                    targets[f"{prefix}{i}_{z_idx}_{k}"] = vol[k, j]
        return targets

    def _restack(prefix, volumes, transformed):
        # Rebuild (channels, depth, ...) arrays from the transformed planes.
        restacked = []
        for i, vol in enumerate(volumes):
            per_channel = [
                np.concatenate(
                    [transformed[f"{prefix}{i}_{j}_{k}"][None, None] for j in range(depth)],
                    axis=1,
                )
                for k in range(len(vol))
            ]
            if len(vol) == 1:
                restacked.append(per_channel[0])
            else:
                restacked.append(np.concatenate(per_channel, axis=0))
        return restacked

    img_kwargs = _as_targets("img", img_tuples)
    noise_kwargs = _as_targets("noise", noise_tuples)
    self._rotation_transform.add_targets(
        {key: "image" for key in [*img_kwargs, *noise_kwargs]}
    )
    rot_dic = self._rotation_transform(
        image=img_tuples[0][0][0], **img_kwargs, **noise_kwargs
    )
    return _restack("img", img_tuples, rot_dic), _restack("noise", noise_tuples, rot_dic)
|
|
1048
|
+
|
|
1049
|
+
def get_uncorrelated_img_tuples(self, index):
    """Assemble channels taken from different random patches.

    The content of channels such as actin and nuclei is "correlated" in its
    location. Channel 0 is kept from the patch at ``index``; every other
    channel ``c`` is taken from the patch at a fresh random dataset index,
    which makes the channels "uncorrelated".

    Returns:
        ``(img_tuples, noise_tuples)``; noise is not supported, so
        ``noise_tuples`` is always empty.
    """
    img_tuples, noise_tuples = self._get_img(index)
    assert len(noise_tuples) == 0
    # Capture the channel count before rebuilding the list: iterating over
    # range(1, len(img_tuples)) after shrinking it to one element meant no
    # channel was ever replaced.
    num_channels = len(img_tuples)
    uncorrelated = [img_tuples[0]]
    for ch_idx in range(1, num_channels):
        new_index = np.random.randint(len(self))
        other_img_tuples, _ = self._get_img(new_index)
        uncorrelated.append(other_img_tuples[ch_idx])
    return uncorrelated, noise_tuples
|
|
1063
|
+
|
|
1064
|
+
def __getitem__(
    self, index: Union[int, tuple[int, int]]
) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(input, normalized_target[, alpha][, index])`` for one sample.

    The input can be a real microscopy image or channels summed here. If noise
    arrays are loaded, the first one is added to every channel to form the
    input (scaled by sqrt(2) when the input is a sum). The remaining ones are
    added channel-wise to form the target.
    """
    if self._train_index_switcher is not None:
        index = self._get_index_from_valid_target_logic(index)

    use_uncorrelated = (
        self._uncorrelated_channels
        and np.random.rand() < self._uncorrelated_channel_probab
    )
    if use_uncorrelated:
        img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
    else:
        img_tuples, noise_tuples = self._get_img(index)

    # NOTE(review): this rejects empty-patch replacement unconditionally, so
    # the replacement branch below can never run. The message mentions noise,
    # but the condition does not check for noise.
    assert (
        self._empty_patch_replacement_enabled != True
    ), "This is not supported with noise"

    # Replace one channel with background content, with the given probability.
    if self._empty_patch_replacement_enabled and (
        np.random.rand() < self._empty_patch_replacement_probab
    ):
        img_tuples = self.replace_with_empty_patch(img_tuples)

    # The image tuples are noisy already; noise tuples are optional extras.
    if self._enable_rotation:
        img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

    n_noise = len(noise_tuples)
    if n_noise > 0:
        scale = np.sqrt(2) if self._input_is_sum else 1.0
        input_tuples = [img + noise_tuples[0] * scale for img in img_tuples]
    else:
        input_tuples = img_tuples

    # Weight the individual channels (alpha is typically fixed).
    inp, alpha = self._compute_input(input_tuples)

    if n_noise >= 1:
        img_tuples = [img + nz for img, nz in zip(img_tuples, noise_tuples[1:])]

    target = self._compute_target(img_tuples, alpha)
    output = [inp, self.normalize_target(target)]
    if self._return_alpha:
        output.append(alpha)
    if self._return_index:
        output.append(index)
    return tuple(output)
|