careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Here, we have multiple folders, each containing images of a single channel.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from functools import cache
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from .types import DataSplitType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def l2(x):
    """Return the root-mean-square (quadratic mean) of the values in *x*."""
    squared = np.array(x) ** 2
    return np.sqrt(squared.mean())
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MultiCropDset:
    """Dataset of random, same-size crops drawn from per-channel image lists.

    ``self._data`` holds one list of 2-D images per channel.  A sample is one
    random crop per channel: the target is the per-channel stack of crops and
    the input is their pixel-wise sum (hence the ``input_is_sum`` requirement).
    """

    def __init__(
        self,
        data_config,
        fpath: str,
        load_data_fn=None,
        val_fraction=None,
        test_fraction=None,
    ):
        assert (
            data_config.input_is_sum == True
        ), "This dataset is designed for sum of images"

        self._img_sz = data_config.image_size
        self._enable_rotation = data_config.enable_rotation_aug

        self._background_values = data_config.background_values
        self._data = load_data_fn(
            data_config, fpath, data_config.datasplit_type, val_fraction, test_fraction
        )

        # remove upper quantiles, crucial for removing puncta
        self.max_val = data_config.max_val
        if self.max_val is not None:
            for ch_idx, data in enumerate(self._data):
                if self.max_val[ch_idx] is not None:
                    for idx in range(len(data)):
                        data[idx][data[idx] > self.max_val[ch_idx]] = self.max_val[
                            ch_idx
                        ]

        # remove per-channel background offsets, promoting to float32 so the
        # subtraction cannot underflow unsigned integer data
        if self._background_values is not None:
            final_data_arr = []
            for ch_idx, data in enumerate(self._data):
                data_float = [x.astype(np.float32) for x in data]
                final_data_arr.append(
                    [x - self._background_values[ch_idx] for x in data_float]
                )
            self._data = final_data_arr

        print(
            f"{self.__class__.__name__} N:{len(self)} Rot:{self._enable_rotation} Ch:{len(self._data)} MaxVal:{self.max_val} Bg:{self._background_values}"
        )

    def get_max_val(self):
        """Return the per-channel clipping thresholds (may be ``None``)."""
        return self.max_val

    def compute_mean_std(self):
        """Estimate mean/std of targets and input by Monte-Carlo crop sampling.

        Returns
        -------
        (dict, dict)
            Mean and std dicts with keys ``"target"`` (shape ``(C, 1, 1)``)
            and ``"input"`` (shape ``(1, 1, 1)``).  Stds are aggregated with
            the RMS helper ``l2`` (variances add, not deviations).
        """
        mean_tar_dict = defaultdict(list)
        std_tar_dict = defaultdict(list)
        mean_inp = []
        std_inp = []
        for _ in range(30000):
            crops = []
            for ch_idx in range(len(self._data)):
                crop = self.sample_crop(ch_idx)
                mean_tar_dict[ch_idx].append(np.mean(crop))
                std_tar_dict[ch_idx].append(np.std(crop))
                crops.append(crop)

            # the model input is the pixel-wise sum over channels
            inp = 0
            for img in crops:
                inp += img

            mean_inp.append(np.mean(inp))
            std_inp.append(np.std(inp))

        output_mean = defaultdict(list)
        output_std = defaultdict(list)
        NC = len(self._data)
        for ch_idx in range(NC):
            output_mean["target"].append(np.mean(mean_tar_dict[ch_idx]))
            output_std["target"].append(l2(std_tar_dict[ch_idx]))

        output_mean["target"] = np.array(output_mean["target"]).reshape(NC, 1, 1)
        output_std["target"] = np.array(output_std["target"]).reshape(NC, 1, 1)

        output_mean["input"] = np.array([np.mean(mean_inp)]).reshape(1, 1, 1)
        output_std["input"] = np.array([l2(std_inp)]).reshape(1, 1, 1)
        return dict(output_mean), dict(output_std)

    def set_mean_std(self, mean_dict, std_dict):
        """Install the normalization statistics used by ``__getitem__``."""
        self._data_mean = mean_dict
        self._data_std = std_dict

    def get_mean_std(self):
        """Return the ``(mean_dict, std_dict)`` set via :meth:`set_mean_std`."""
        return self._data_mean, self._data_std

    def get_num_frames(self):
        """Return the number of channels (one image list per channel)."""
        return len(self._data)

    def crop_probablities(self, ch_idx):
        """Probability of picking each image of channel *ch_idx*, ∝ pixel count.

        Memoized per instance: a ``functools.cache`` on a method would key on
        ``self`` and keep every instance alive for the cache's lifetime
        (ruff B019), so a plain per-instance dict is used instead.
        """
        memo = getattr(self, "_crop_prob_cache", None)
        if memo is None:
            memo = self._crop_prob_cache = {}
        if ch_idx not in memo:
            sizes = np.array([np.prod(x.shape) for x in self._data[ch_idx]])
            memo[ch_idx] = sizes / sizes.sum()
        return memo[ch_idx]

    def sample_crop(self, ch_idx):
        """Return one random ``self._img_sz`` crop from channel *ch_idx*.

        Images are picked with probability proportional to their area; images
        smaller than the patch are rejected and re-sampled, and after 100
        failed attempts a ``ValueError`` is raised.
        """
        count = 0
        while True:
            count += 1
            idx = np.random.choice(
                len(self._data[ch_idx]), p=self.crop_probablities(ch_idx)
            )
            data = self._data[ch_idx][idx]
            h_range = data.shape[0] - self._img_sz[0]
            w_range = data.shape[1] - self._img_sz[1]
            if h_range >= 0 and w_range >= 0:
                # +1: randint's upper bound is exclusive.  Without it, an
                # image exactly the patch size raised ValueError (low >= high)
                # and the last valid offset was never sampled.
                h = np.random.randint(0, h_range + 1)
                w = np.random.randint(0, w_range + 1)
                return data[h : h + self._img_sz[0], w : w + self._img_sz[1]]
            if count > 100:
                raise ValueError("Cannot find a valid crop")

    def len_per_channel(self, ch_idx):
        """Number of non-overlapping patches worth of pixels in *ch_idx*."""
        return np.sum([np.prod(x.shape) for x in self._data[ch_idx]]) / np.prod(
            self._img_sz
        )

    def imgs_for_patch(self):
        """Sample one crop per channel (the raw material for one sample)."""
        return [self.sample_crop(ch_idx) for ch_idx in range(len(self._data))]

    def __len__(self):
        # epoch length: the channel with the most patch-equivalents wins
        len_per_channel = [
            self.len_per_channel(ch_idx) for ch_idx in range(len(self._data))
        ]
        return int(np.max(len_per_channel))

    def _rotate(self, img_tuples):
        return self._rotate2D(img_tuples)

    def _rotate2D(self, img_tuples):
        """Apply one shared random rotation to every image of every tuple.

        NOTE(review): ``self._rotation_transform`` is never assigned in this
        class — presumably injected by a subclass or caller; confirm before
        enabling ``enable_rotation_aug``.
        """
        img_kwargs = {}
        for i, img in enumerate(img_tuples):
            for k in range(len(img)):
                img_kwargs[f"img{i}_{k}"] = img[k]

        keys = list(img_kwargs.keys())
        self._rotation_transform.add_targets({k: "image" for k in keys})
        rot_dic = self._rotation_transform(image=img_tuples[0][0], **img_kwargs)

        rotated_img_tuples = []
        for i, img in enumerate(img_tuples):
            if len(img) == 1:
                rotated_img_tuples.append(rot_dic[f"img{i}_0"][None])
            else:
                rotated_img_tuples.append(
                    np.concatenate(
                        [rot_dic[f"img{i}_{k}"][None] for k in range(len(img))], axis=0
                    )
                )

        return rotated_img_tuples

    def _compute_input(self, imgs):
        """Sum the per-channel crops and normalize with the "input" stats."""
        inp = 0
        for img in imgs:
            inp += img

        inp = inp[None]
        inp = (inp - self._data_mean["input"]) / (self._data_std["input"])
        return inp

    def _compute_target(self, imgs):
        """Stack the per-channel crops and normalize with the "target" stats."""
        imgs = np.stack(imgs)
        target = (imgs - self._data_mean["target"]) / (self._data_std["target"])
        return target

    def __getitem__(self, idx):
        """Return one ``(input, target)`` pair; *idx* is ignored (random crops)."""
        imgs = self.imgs_for_patch()
        if self._enable_rotation:
            imgs = self._rotate(imgs)

        inp = self._compute_input(imgs)
        target = self._compute_target(imgs)
        return inp, target
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Callable, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
|
|
7
|
+
from .config import MicroSplitDataConfig
|
|
8
|
+
from .lc_dataset import LCMultiChDloader
|
|
9
|
+
from .multich_dataset import MultiChDloader
|
|
10
|
+
from .types import DataSplitType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TwoChannelData(Sequence):
    """Flat, read-only view over paired lists of N*H*W arrays.

    Each pair of same-shape arrays is fused along a new trailing channel axis,
    and indexing returns ``(frame, paths)`` where ``frame`` has shape
    ``(H, W, 2)`` and ``paths`` is the ``(path1, path2)`` tuple of the source
    arrays (or ``None`` when no paths were supplied).
    """

    def __init__(self, data_arr1, data_arr2, paths_data1=None, paths_data2=None):
        assert len(data_arr1) == len(data_arr2)
        self.paths1 = paths_data1
        self.paths2 = paths_data2

        self._data = []
        for arr1, arr2 in zip(data_arr1, data_arr2):
            assert arr1.shape == arr2.shape
            assert (
                len(arr1.shape) == 3
            ), f"Each element in data arrays should be a N*H*W, but {arr1.shape}"
            fused = np.concatenate([arr1[..., None], arr2[..., None]], axis=-1)
            self._data.append(fused)

    def __len__(self):
        return sum(arr.shape[0] for arr in self._data)

    def __getitem__(self, idx):
        offset = 0
        for arr_idx, arr in enumerate(self._data):
            if idx - offset < arr.shape[0]:
                frame = arr[idx - offset]
                if self.paths1 is None:
                    return frame, None
                return frame, (self.paths1[arr_idx], self.paths2[arr_idx])
            offset += arr.shape[0]
        raise IndexError("Index out of range")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class MultiChannelData(Sequence):
    """Flat, read-only view over a list of N*H*W arrays.

    Indexing returns ``(frame, path)`` where ``path`` is the source path of
    the array the frame came from (or ``None`` when no paths were supplied).
    """

    def __init__(self, data_arr, paths=None):
        self.paths = paths
        self._data = data_arr

    def __len__(self):
        return sum(arr.shape[0] for arr in self._data)

    def __getitem__(self, idx):
        offset = 0
        for arr_idx, arr in enumerate(self._data):
            if idx - offset < arr.shape[0]:
                frame = arr[idx - offset]
                path = None if self.paths is None else self.paths[arr_idx]
                return frame, path
            offset += arr.shape[0]
        raise IndexError("Index out of range")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class SingleFileLCDset(LCMultiChDloader):
    """Lateral-context (multiscale) loader fed from an already-loaded array.

    The base class normally reads its data from disk inside ``load_data``;
    here the array is handed in pre-loaded, so ``load_data`` just installs it
    and the background/clipping hook is disabled (presumably handled by the
    owning multi-file dataset — confirm against ``MultiFileDset``).
    """

    def __init__(
        self,
        preloaded_data: NDArray,
        data_config: MicroSplitDataConfig,
        fpath: str,
        load_data_fn: Callable,
        val_fraction=None,
        test_fraction=None,
    ):
        # must be assigned before super().__init__, which calls load_data()
        self._preloaded_data = preloaded_data
        super().__init__(
            data_config,
            fpath,
            load_data_fn=load_data_fn,
            val_fraction=val_fraction,
            test_fraction=test_fraction,
        )

    @property
    def data_path(self):
        """Path of the single file backing this dataset."""
        return self._fpath

    def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
        # intentional no-op: clipping is coordinated externally
        pass

    def load_data(
        self,
        data_config: MicroSplitDataConfig,
        datasplit_type: DataSplitType,
        load_data_fn: Callable,
        val_fraction=None,
        test_fraction=None,
        allow_generation=None,
    ):
        """Install the pre-loaded array instead of reading from disk."""
        self._data = self._preloaded_data
        # legacy per-channel fields must be absent or plain path strings
        for ch_key in ("channel_1", "channel_2", "channel_3"):
            assert ch_key not in data_config or isinstance(
                getattr(data_config, ch_key), str
            )
        self._loaded_data_preprocessing(data_config)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class SingleFileDset(MultiChDloader):
    """Single-file loader fed from an already-loaded array.

    The base class normally reads its data from disk inside ``load_data``;
    here the array is handed in pre-loaded, so ``load_data`` just installs it
    and the background/clipping hook is disabled (presumably handled by the
    owning multi-file dataset — confirm against ``MultiFileDset``).
    """

    def __init__(
        self,
        preloaded_data: NDArray,
        data_config: MicroSplitDataConfig,
        fpath: str,
        load_data_fn: Callable,
        val_fraction=None,
        test_fraction=None,
    ):
        # must be assigned before super().__init__, which calls load_data()
        self._preloaded_data = preloaded_data
        super().__init__(
            data_config,
            fpath,
            load_data_fn=load_data_fn,
            val_fraction=val_fraction,
            test_fraction=test_fraction,
        )

    def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
        # intentional no-op: clipping is coordinated externally
        pass

    @property
    def data_path(self):
        """Path of the single file backing this dataset."""
        return self._fpath

    def load_data(
        self,
        data_config: MicroSplitDataConfig,
        datasplit_type: DataSplitType,
        load_data_fn: Callable[..., NDArray],
        val_fraction=None,
        test_fraction=None,
        allow_generation=None,
    ):
        """Install the pre-loaded array instead of reading from disk."""
        self._data = self._preloaded_data
        outdated_msg = "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
        # legacy per-channel fields are no longer supported at all here
        for ch_key in ("channel_1", "channel_2", "channel_3"):
            assert ch_key not in data_config, outdated_msg
        self._loaded_data_preprocessing(data_config)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class MultiFileDset:
|
|
172
|
+
"""
|
|
173
|
+
Here, we handle dataset having multiple files. Each file can have a different spatial dimension and number of frames (Z stack).
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
    def __init__(
        self,
        data_config: MicroSplitDataConfig,
        fpath: str,
        load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
        val_fraction=None,
        test_fraction=None,
    ):
        """Load every file and wrap each one in its own single-file dataset.

        One child dataset is built per element yielded by *load_data_fn*
        (``SingleFileLCDset`` when multiscale lateral context is enabled,
        otherwise ``SingleFileDset``); the max-val clipping pass is then
        applied once, globally, across all children.
        """
        self._fpath = fpath
        # load_data_fn yields (frame, path-or-path-tuple) pairs, one per file
        data: Union[TwoChannelData, MultiChannelData] = load_data_fn(
            data_config,
            self._fpath,
            data_config.datasplit_type,
            val_fraction=val_fraction,
            test_fraction=test_fraction,
        )
        self.dsets = []

        for i in range(len(data)):
            prefetched_data, fpath_tuple = data[i]
            if (
                data_config.multiscale_lowres_count is not None
                and data_config.multiscale_lowres_count > 1
            ):
                # lateral-context variant; [None] adds the leading frame axis
                # the child loaders expect
                self.dsets.append(
                    SingleFileLCDset(
                        prefetched_data[None],
                        data_config,
                        fpath_tuple,
                        load_data_fn,
                        val_fraction=val_fraction,
                        test_fraction=test_fraction,
                    )
                )

            else:
                self.dsets.append(
                    SingleFileDset(
                        prefetched_data[None],
                        data_config,
                        fpath_tuple,
                        load_data_fn,
                        val_fraction=val_fraction,
                        test_fraction=test_fraction,
                    )
                )

        # clipping is coordinated here (single global threshold across
        # children); the per-child hook is a no-op in the child classes
        self.rm_bkground_set_max_val_and_upperclip_data(
            data_config.max_val, data_config.datasplit_type
        )
        count = 0
        avg_height = 0
        avg_width = 0
        for dset in self.dsets:
            shape = dset.get_data_shape()
            avg_height += shape[1]
            avg_width += shape[2]
            count += shape[0]

        avg_height = int(avg_height / len(self.dsets))
        avg_width = int(avg_width / len(self.dsets))
        print(
            f"{self.__class__.__name__} avg height: {avg_height}, avg width: {avg_width}, count: {count}"
        )
|
|
241
|
+
|
|
242
|
+
    def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
        """Install the clipping threshold on all children, then clip.

        Order matters: the thresholds must be set before upperclip_data()
        runs on the children.
        """
        self.set_max_val(max_val, datasplit_type)
        self.upperclip_data()
|
|
245
|
+
|
|
246
|
+
def set_mean_std(self, mean_val, std_val):
|
|
247
|
+
for dset in self.dsets:
|
|
248
|
+
dset.set_mean_std(mean_val, std_val)
|
|
249
|
+
|
|
250
|
+
def get_mean_std(self):
|
|
251
|
+
return self.dsets[0].get_mean_std()
|
|
252
|
+
|
|
253
|
+
def compute_max_val(self):
|
|
254
|
+
max_val_arr = []
|
|
255
|
+
for dset in self.dsets:
|
|
256
|
+
max_val_arr.append(dset.compute_max_val())
|
|
257
|
+
return np.max(max_val_arr)
|
|
258
|
+
|
|
259
|
+
def set_max_val(self, max_val, datasplit_type):
    """Propagate the clipping ceiling to every sub-dataset.

    For the training split ``max_val`` must be None and is derived from the
    data itself (max over all sub-datasets); for val/test the caller supplies
    the value that was computed on the training data.

    Raises
    ------
    ValueError
        If ``max_val`` is given for the train split.
    """
    if datasplit_type == DataSplitType.Train:
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if max_val is not None:
            raise ValueError(
                "max_val must be None for the train split; it is computed from the data."
            )
        max_val = self.compute_max_val()
    for dset in self.dsets:
        dset.set_max_val(max_val, datasplit_type)
|
|
265
|
+
|
|
266
|
+
def upperclip_data(self):
    """Ask each sub-dataset to clip its data at the configured ceiling."""
    for sub_dset in self.dsets:
        sub_dset.upperclip_data()
|
|
269
|
+
|
|
270
|
+
def get_max_val(self):
    """Return the clipping ceiling of the first sub-dataset.

    Assumes the ceiling is identical on all sub-datasets (it is set on every
    one by ``set_max_val``).
    """
    return self.dsets[0].get_max_val()
|
|
272
|
+
|
|
273
|
+
def get_img_sz(self):
    """Return the patch/image size of the first sub-dataset.

    Assumes all sub-datasets share one size (set together via ``set_img_sz``).
    """
    return self.dsets[0].get_img_sz()
|
|
275
|
+
|
|
276
|
+
def set_img_sz(self, image_size, grid_size):
    """Forward the new patch and grid size to every sub-dataset."""
    for sub_dset in self.dsets:
        sub_dset.set_img_sz(image_size, grid_size)
|
|
279
|
+
|
|
280
|
+
def compute_mean_std(self):
    """Average the per-dataset {"target", "input"} mean/std dicts.

    Each sub-dataset reports its own statistics; the aggregate is the plain
    arithmetic mean over sub-datasets, returned as ``(mean_dict, std_dict)``.
    """
    n_dsets = len(self.dsets)
    stats = [sub_dset.compute_mean_std() for sub_dset in self.dsets]
    agg_mean = {"target": 0, "input": 0}
    agg_std = {"target": 0, "input": 0}
    for key in ("target", "input"):
        for mean_d, std_d in stats:
            agg_mean[key] += mean_d[key]
            agg_std[key] += std_d[key]
        agg_mean[key] /= n_dsets
        agg_std[key] /= n_dsets
    return agg_mean, agg_std
|
|
296
|
+
|
|
297
|
+
def compute_individual_mean_std(self):
    """Average the scalar per-dataset (mean, std) over all sub-datasets."""
    stats = [sub_dset.compute_individual_mean_std() for sub_dset in self.dsets]
    total_mean = sum(m for m, _ in stats)
    total_std = sum(s for _, s in stats)
    n_dsets = len(stats)
    return total_mean / n_dsets, total_std / n_dsets
|
|
305
|
+
|
|
306
|
+
def get_num_frames(self):
    """Return the number of sub-datasets (presumably one per file/frame —
    confirm against the dataset construction above)."""
    return len(self.dsets)
|
|
308
|
+
|
|
309
|
+
def reduce_data(
    self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
):
    """Keep only the sub-datasets whose indices appear in ``t_list``.

    Spatial reduction is not supported at this multi-file level, so the
    ``h_*``/``w_*`` arguments must all be None. With ``t_list=None`` the call
    is a no-op (previously this crashed with a TypeError).

    Raises
    ------
    ValueError
        If any spatial bound is provided.
    """
    # Explicit validation instead of `assert` (stripped under `python -O`).
    if any(v is not None for v in (h_start, h_end, w_start, w_end)):
        raise ValueError(
            "Spatial reduction (h_start/h_end/w_start/w_end) is not supported; "
            "pass t_list only."
        )
    if t_list is None:
        # Nothing to select: keep all sub-datasets.
        return
    self.dsets = [self.dsets[t] for t in t_list]
    print(
        f"[{self.__class__.__name__}] Data reduced. New data count: {len(self.dsets)}"
    )
|
|
320
|
+
|
|
321
|
+
def __len__(self):
|
|
322
|
+
out = 0
|
|
323
|
+
for dset in self.dsets:
|
|
324
|
+
out += len(dset)
|
|
325
|
+
return out
|
|
326
|
+
|
|
327
|
+
def __getitem__(self, idx):
|
|
328
|
+
cum_len = 0
|
|
329
|
+
for dset in self.dsets:
|
|
330
|
+
cum_len += len(dset)
|
|
331
|
+
if idx < cum_len:
|
|
332
|
+
rel_idx = idx - (cum_len - len(dset))
|
|
333
|
+
return dset[rel_idx]
|
|
334
|
+
|
|
335
|
+
raise IndexError("Index out of range")
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DataType(Enum):
    """Identifiers for the supported microscopy datasets.

    Values are explicit, stable integers — presumably persisted in saved
    configurations, so do not renumber existing members; only append.
    """

    HTH24Data = 0
    HTLIF24Data = 1
    PaviaP24Data = 2
    TavernaSox2GolgiV2 = 3
    Dao3ChannelWithInput = 4
    ExpMicroscopyV1 = 5
    ExpMicroscopyV2 = 6
    Dao3Channel = 7
    TavernaSox2Golgi = 8
    HTIba1Ki67 = 9
    OptiMEM100_014 = 10
    SeparateTiffData = 11
    BioSR_MRC = 12
    HTH23BData = 13  # puncta, in case we have differently sized crops for each channel.
    Care3D = 14
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DataSplitType(Enum):
    """Which portion of a dataset a loader serves (or all of it)."""

    All = 0
    Train = 1
    Val = 2
    Test = 3
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TilingMode(Enum):
    """How tiles that do not fit the image boundary evenly are handled.

    NOTE(review): exact semantics of each mode are defined by the tiling code
    elsewhere; the names suggest trimming, padding, or shifting the last tile.
    """

    TrimBoundary = 0
    PadBoundary = 1
    ShiftBoundary = 2
|
|
File without changes
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions needed by dataloader & co.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from skimage.io import imread, imsave
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_tiff(path):
    """Load a TIFF stack via the tifffile plugin.

    Returns a 4d numpy array: num_imgs*h*w*num_channels
    """
    return imread(path, plugin="tifffile")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def save_tiff(path, data):
    """Write ``data`` to ``path`` as a TIFF using the tifffile plugin."""
    imsave(path, data, plugin="tifffile")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_tiffs(paths):
    """Load several TIFF files and stack them along the first (frame) axis."""
    return np.concatenate([load_tiff(fpath) for fpath in paths], axis=0)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def split_in_half(s, e):
    """Split the integer range [s, e) into two consecutive halves.

    Returns two lists of indices; when the span is odd the second half gets
    the extra element. Element types follow ``np.arange`` (numpy integers).
    """
    span = e - s
    first_half = [offset + s for offset in np.arange(span // 2)]
    second_half = [offset + s for offset in np.arange(span // 2, span)]
    return first_half, second_half
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def adjust_for_imbalance_in_fraction_value(
|
|
36
|
+
val: List[int],
|
|
37
|
+
test: List[int],
|
|
38
|
+
val_fraction: float,
|
|
39
|
+
test_fraction: float,
|
|
40
|
+
total_size: int,
|
|
41
|
+
):
|
|
42
|
+
"""
|
|
43
|
+
here, val and test are divided almost equally. Here, we need to take into account their respective fractions
|
|
44
|
+
and pick elements rendomly from one array and put in the other array.
|
|
45
|
+
"""
|
|
46
|
+
if val_fraction == 0:
|
|
47
|
+
test += val
|
|
48
|
+
val = []
|
|
49
|
+
elif test_fraction == 0:
|
|
50
|
+
val += test
|
|
51
|
+
test = []
|
|
52
|
+
else:
|
|
53
|
+
diff_fraction = test_fraction - val_fraction
|
|
54
|
+
if diff_fraction > 0:
|
|
55
|
+
imb_count = int(diff_fraction * total_size / 2)
|
|
56
|
+
val = list(np.random.RandomState(seed=955).permutation(val))
|
|
57
|
+
test += val[:imb_count]
|
|
58
|
+
val = val[imb_count:]
|
|
59
|
+
elif diff_fraction < 0:
|
|
60
|
+
imb_count = int(-1 * diff_fraction * total_size / 2)
|
|
61
|
+
test = list(np.random.RandomState(seed=955).permutation(test))
|
|
62
|
+
val += test[:imb_count]
|
|
63
|
+
test = test[imb_count:]
|
|
64
|
+
return val, test
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_datasplit_tuples(
    val_fraction: float,
    test_fraction: float,
    total_size: int,
    starting_test: bool = False,
):
    """Compute ``(train, val, test)`` index lists for ``total_size`` frames.

    With ``starting_test=True`` the layout is a contiguous test block, then a
    val block, then train. Otherwise val/test are interleaved in small chunks
    at the start and train takes the remainder, after which the val/test
    balance is corrected to honour the requested fractions.

    Fix: the ``starting_test`` path previously indexed ``test[-1]``/``val[-1]``
    and crashed with IndexError when either block was empty (fraction rounding
    to zero); block boundaries are now computed from the block sizes.
    """
    if starting_test:
        # test => val => train, as contiguous blocks.
        n_test = int(total_size * test_fraction)
        n_val = int(total_size * val_fraction)
        test = list(range(0, n_test))
        val = list(range(n_test, n_test + n_val))
        train = list(range(n_test + n_val, total_size))
        return train, val, test

    # {test,val} => train
    test_val_size = int((val_fraction + test_fraction) * total_size)
    train = list(range(test_val_size, total_size))

    if test_val_size == 0:
        return train, [], []

    # Split the test and validation in chunks.
    chunksize = max(1, min(3, test_val_size // 2))
    nchunks = test_val_size // chunksize

    test = []
    val = []
    s = 0
    # Alternate chunks between val (even i) and test (odd i).
    for i in range(nchunks):
        if i % 2 == 0:
            val += list(np.arange(s, s + chunksize))
        else:
            test += list(np.arange(s, s + chunksize))
        s += chunksize

    # Distribute the leftover indices (test_val_size may not divide evenly).
    if i % 2 == 0:
        test += list(np.arange(s, test_val_size))
    else:
        p1, p2 = split_in_half(s, test_val_size)
        test += p1
        val += p2

    val, test = adjust_for_imbalance_in_fraction_value(
        val, test, val_fraction, test_fraction, total_size
    )

    return train, val, test
|