careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,406 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pixel manipulation methods.
|
|
3
|
+
|
|
4
|
+
Pixel manipulation is used in N2V and similar algorithms to replace the value of
|
|
5
|
+
masked pixels.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _apply_struct_mask(
    patch: np.ndarray,
    coords: np.ndarray,
    struct_params: StructMaskParameters,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Overwrite the structN2V neighbourhood of each masked pixel, in place.

    Every row of `coords` is the center of a 1D mask of `struct_params.span`
    pixels laid along patch axis `-1 - struct_params.axis`. All mask pixels
    except the center are replaced by values drawn uniformly between the
    minimum and the maximum of `patch`. Mask pixels falling outside of the
    patch along that axis are ignored.

    The mask is one-dimensional, so it only ever extends along a single axis of
    the patch, whether the patch is 2D or 3D.

    Parameters
    ----------
    patch : np.ndarray
        Patch to be manipulated, 2D or 3D. It is modified in place.
    coords : np.ndarray
        Coordinates of the mask centers, one row per center and one column per
        patch axis.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None
        Random number generator, a new default generator is created if None.

    Returns
    -------
    np.ndarray
        The same array object as `patch`, with the structN2V masks applied.
    """
    generator = np.random.default_rng() if rng is None else rng

    # patch axis along which the mask extends, counted from the last axis
    moving_axis = -1 - struct_params.axis

    # signed distances of the mask pixels to the mask center, center excluded
    span = struct_params.span
    offsets = np.arange(span) - span // 2
    offsets = offsets[offsets != 0]

    # one displacement vector per mask pixel, non-zero along the moving axis only
    shifts = np.zeros((offsets.size, patch.ndim), dtype=offsets.dtype)
    shifts[:, moving_axis] = offsets

    # combine every displacement with every center; rows are ordered
    # displacement first, then center
    neighbours = (shifts[:, None, :] + coords[None, :, :]).reshape(-1, patch.ndim)

    # discard the neighbours that fall outside of the patch along the moving axis
    along_axis = neighbours[:, moving_axis]
    last_valid = patch.shape[moving_axis] - 1
    neighbours = neighbours[(along_axis >= 0) & (along_axis <= last_valid)]

    # replace the remaining neighbours with values from a flat distribution
    patch[tuple(neighbours.T)] = generator.uniform(
        patch.min(), patch.max(), size=neighbours.shape[0]
    )

    return patch
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _odd_jitter_func(step: float, rng: np.random.Generator) -> float:
    """
    Randomly round a grid step down or up to an integer value.

    This is done to account for cases where the step size is not an integer: a
    fractional step is rounded down or up with equal probability, while an
    integer step is returned unchanged.

    Parameters
    ----------
    step : float
        Step size of the grid, output of np.linspace.
    rng : np.random.Generator
        Random number generator.

    Returns
    -------
    float
        The step rounded to an integer value (returned as a floating point
        scalar, not as an array).
    """
    # The draw is done unconditionally, even when it is not needed, so that the
    # generator advances by the same amount whatever the value of `step`.
    round_up = rng.integers(0, 2) != 0

    rounded_down = np.floor(step)
    if rounded_down == step or not round_up:
        return rounded_down

    return np.ceil(step)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _get_stratified_coords(
|
|
107
|
+
mask_pixel_perc: float,
|
|
108
|
+
shape: tuple[int, ...],
|
|
109
|
+
rng: np.random.Generator | None = None,
|
|
110
|
+
) -> np.ndarray:
|
|
111
|
+
"""
|
|
112
|
+
Generate coordinates of the pixels to mask.
|
|
113
|
+
|
|
114
|
+
Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
|
|
115
|
+
the distance between masked pixels is approximately the same.
|
|
116
|
+
|
|
117
|
+
Parameters
|
|
118
|
+
----------
|
|
119
|
+
mask_pixel_perc : float
|
|
120
|
+
Actual (quasi) percentage of masked pixels across the whole image. Used in
|
|
121
|
+
calculating the distance between masked pixels across each axis.
|
|
122
|
+
shape : tuple[int, ...]
|
|
123
|
+
Shape of the input patch.
|
|
124
|
+
rng : np.random.Generator or None
|
|
125
|
+
Random number generator.
|
|
126
|
+
|
|
127
|
+
Returns
|
|
128
|
+
-------
|
|
129
|
+
np.ndarray
|
|
130
|
+
Array of coordinates of the masked pixels.
|
|
131
|
+
"""
|
|
132
|
+
if len(shape) < 2 or len(shape) > 3:
|
|
133
|
+
raise ValueError(
|
|
134
|
+
"Calculating coordinates is only possible for 2D and 3D patches"
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
if rng is None:
|
|
138
|
+
rng = np.random.default_rng()
|
|
139
|
+
|
|
140
|
+
mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
|
|
141
|
+
np.int32
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
# Define a grid of coordinates for each axis in the input patch and the step size
|
|
145
|
+
pixel_coords = []
|
|
146
|
+
steps = []
|
|
147
|
+
for axis_size in shape:
|
|
148
|
+
# make sure axis size is evenly divisible by box size
|
|
149
|
+
num_pixels = int(np.ceil(axis_size / mask_pixel_distance))
|
|
150
|
+
axis_pixel_coords, step = np.linspace(
|
|
151
|
+
0, axis_size, num_pixels, dtype=np.int32, endpoint=False, retstep=True
|
|
152
|
+
)
|
|
153
|
+
# explain
|
|
154
|
+
pixel_coords.append(axis_pixel_coords.T)
|
|
155
|
+
steps.append(step)
|
|
156
|
+
|
|
157
|
+
# Create a meshgrid of coordinates for each axis in the input patch
|
|
158
|
+
coordinate_grid_list = np.meshgrid(*pixel_coords)
|
|
159
|
+
coordinate_grid = np.array(coordinate_grid_list).reshape(len(shape), -1).T
|
|
160
|
+
|
|
161
|
+
grid_random_increment = rng.integers(
|
|
162
|
+
_odd_jitter_func(float(max(steps)), rng) # type: ignore
|
|
163
|
+
* np.ones_like(coordinate_grid).astype(np.int32)
|
|
164
|
+
- 1,
|
|
165
|
+
size=coordinate_grid.shape,
|
|
166
|
+
endpoint=True,
|
|
167
|
+
)
|
|
168
|
+
coordinate_grid += grid_random_increment
|
|
169
|
+
coordinate_grid = np.clip(coordinate_grid, 0, np.array(shape) - 1)
|
|
170
|
+
return coordinate_grid
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _create_subpatch_center_mask(
    subpatch: np.ndarray, center_coords: np.ndarray
) -> np.ndarray:
    """Create a boolean mask selecting all subpatch pixels but the center.

    The mask is always an array with the shape of `subpatch`, including when the
    subpatch holds a single pixel (in which case the mask is entirely False).

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated.
    center_coords : np.ndarray
        Coordinates of the original center before possible crop.

    Returns
    -------
    np.ndarray
        Boolean array of the same shape as `subpatch`, False at the center and
        True everywhere else.
    """
    # The boolean array is built directly: `np.ma.make_mask` collapses a mask
    # without any True value into the scalar `np.ma.nomask`.
    mask = np.ones(subpatch.shape, dtype=bool)
    mask[tuple(center_coords)] = False
    return mask
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _create_subpatch_struct_mask(
    subpatch: np.ndarray, center_coords: np.ndarray, struct_params: StructMaskParameters
) -> np.ndarray:
    """Create a structN2V mask for the subpatch.

    The mask excludes the line of pixels going through the center of the
    subpatch, extending `(struct_params.span - 1) // 2` pixels on each side of
    the center and cropped to the subpatch. The line is laid along subpatch axis
    `-1 - struct_params.axis`, the axis convention used by `_apply_struct_mask`.

    Parameters
    ----------
    subpatch : np.ndarray
        Subpatch to be manipulated.
    center_coords : np.ndarray
        Coordinates of the original center before possible crop.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).

    Returns
    -------
    np.ndarray
        Boolean array of the same shape as `subpatch`, False on the structN2V
        line (center included) and True everywhere else.
    """
    # NOTE(review): assumes `struct_params.axis` counts axes from the last one
    # (0 is the last axis, 1 the second to last), as in `_apply_struct_mask`.
    line_axis = subpatch.ndim - 1 - struct_params.axis
    half_span = (struct_params.span - 1) // 2

    center = [int(coord) for coord in center_coords]

    # extent of the line along its axis, cropped to the subpatch
    start = max(0, center[line_axis] - half_span)
    stop = min(center[line_axis] + half_span + 1, subpatch.shape[line_axis])

    # index pointing at the center on every axis but the line axis
    line_index: list = list(center)
    line_index[line_axis] = slice(start, stop)

    # The boolean array is built directly: `np.ma.make_mask` collapses a mask
    # without any True value into the scalar `np.ma.nomask`.
    mask = np.ones(subpatch.shape, dtype=bool)
    mask[tuple(line_index)] = False
    return mask
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def uniform_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: StructMaskParameters | None = None,
    rng: np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with a neighbor value.

    Manipulated pixels are selected uniformly in a subpatch, away from a grid
    with an approximate uniform probability to be selected across the whole patch.
    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    The input patch is left untouched, the manipulation is applied to a copy.

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to remove the center pixel from the subpatch, by default True.
        When True, the offset 0 is excluded along every axis independently.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None
        Random number generator, a new default generator is created if None. It
        is used for all random draws, including the structN2V values.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Tuple containing the manipulated patch and the corresponding mask. The
        mask is 1 where the replacement changed the pixel value and 0 elsewhere;
        it is computed before the structN2V mask is applied.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)

    # Offsets, relative to a center, spanned by the subpatch along one axis
    roi_span_full = np.arange(
        -np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)
    ).astype(np.int32)

    # Remove the center pixel from the offsets if needed
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full

    # Randomly select an offset for each axis of each center
    random_increment = rng.choice(roi_span, size=subpatch_centers.shape)

    # Clip the coordinates to the patch size
    replacement_coords = np.clip(
        subpatch_centers + random_increment,
        0,
        np.array(patch.shape) - 1,
    )

    # Get the replacement pixels from all subpatches
    replacement_pixels = patch[tuple(replacement_coords.T)]

    # Replace the original pixels with the replacement pixels
    transformed_patch[tuple(subpatch_centers.T)] = replacement_pixels
    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        # forward the generator so that seeded calls are fully reproducible
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return (
        transformed_patch,
        mask,
    )
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def median_manipulate(
    patch: np.ndarray,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: StructMaskParameters | None = None,
    rng: np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version, manipulated pixels are selected randomly away from a grid with an
    approximate uniform probability to be selected across the whole patch.

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    The input patch is left untouched, the manipulation is applied to a copy.

    Parameters
    ----------
    patch : np.ndarray
        Image patch, 2D or 3D, shape (y, x) or (z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed from, by default 11.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span).
    rng : np.random.Generator or None, optional
        Random number generator, by default None. A new default generator is
        created if None. It is used for all random draws, including the
        structN2V values.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Tuple containing the manipulated patch and the corresponding mask. The
        mask is 1 where the replacement changed the pixel value and 0 elsewhere;
        it is computed before the structN2V mask is applied.
    """
    if rng is None:
        rng = np.random.default_rng()

    transformed_patch = patch.copy()

    # Get the coordinates of the pixels to be replaced
    subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)

    # Start (included) and stop (excluded) offsets of a subpatch around its center
    roi_span = np.array(
        [-np.floor(subpatch_size / 2), np.ceil(subpatch_size / 2)]
    ).astype(np.int32)

    # Dimensions: n dims, n centers, (start, stop)
    subpatch_crops_span_full = subpatch_centers[np.newaxis, ...].T + roi_span

    # Crop the subpatches to the patch bounds
    subpatch_crops_span_clipped = np.clip(
        subpatch_crops_span_full,
        a_min=np.zeros_like(patch.shape)[:, np.newaxis, np.newaxis],
        a_max=np.array(patch.shape)[:, np.newaxis, np.newaxis],
    )

    for idx, center in enumerate(subpatch_centers):
        # (start, stop) of the subpatch along each axis
        subpatch_coords = subpatch_crops_span_clipped[:, idx, ...]
        idxs = [
            slice(x[0], x[1]) if x[1] - x[0] > 0 else slice(0, 1)
            for x in subpatch_coords
        ]
        subpatch = patch[tuple(idxs)]

        # position of the center within the (possibly cropped) subpatch
        subpatch_center_adjusted = center - subpatch_coords[:, 0]

        # pixels excluded from the median: the center alone, or the whole
        # structN2V mask when one is requested
        if struct_params is None:
            subpatch_mask = _create_subpatch_center_mask(
                subpatch, subpatch_center_adjusted
            )
        else:
            subpatch_mask = _create_subpatch_struct_mask(
                subpatch, subpatch_center_adjusted, struct_params
            )
        transformed_patch[tuple(center)] = np.median(subpatch[subpatch_mask])

    mask = np.where(transformed_patch != patch, 1, 0).astype(np.uint8)

    if struct_params is not None:
        # forward the generator so that seeded calls are fully reproducible
        transformed_patch = _apply_struct_mask(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return (
        transformed_patch,
        mask,
    )
|