careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""N2V manipulation functions for PyTorch."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _apply_struct_mask_torch(
    patch: torch.Tensor,
    coords: torch.Tensor,
    struct_params: StructMaskParameters,
    rng: torch.Generator | None = None,
) -> torch.Tensor:
    """Apply structN2V masks to patch.

    Each point in `coords` corresponds to the center of a mask. Masks are
    parameterized by `struct_params`, and pixels in the mask (with respect to
    `coords`) are replaced by a random value drawn uniformly between the minimum
    and the maximum of the corresponding sample.

    The structN2V mask is a line along a single spatial axis (the last axis for
    `axis=0`, the second-to-last axis for `axis=1`) through each coordinate in
    `coords`. The center pixel itself is not modified.

    Parameters
    ----------
    patch : torch.Tensor
        Patch to be manipulated, (batch, y, x) or (batch, z, y, x). It is modified
        in place.
    coords : torch.Tensor
        Coordinates of the ROI (subpatch) centers, shape (n_points, patch.ndim),
        where the first column is the batch index.
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator, optional
        Random number generator, expected on `patch.device`. If None, a new
        generator seeded non-deterministically is created on `patch.device`.

    Returns
    -------
    torch.Tensor
        Patch with the structN2V mask applied (the same object as `patch`).
    """
    if rng is None:
        # A freshly constructed torch.Generator always starts from the same
        # default seed; seed it explicitly so calls do not repeat the same values.
        rng = torch.Generator(device=patch.device)
        rng.seed()

    # Relative axis: axis 0 moves along the last (x) axis, axis 1 along y
    moving_axis = -1 - struct_params.axis

    # Create a mask array, singleton everywhere except along the moving axis
    mask_shape = [1] * patch.ndim
    mask_shape[moving_axis] = struct_params.span
    mask = torch.ones(mask_shape, device=patch.device)

    center = torch.tensor(mask.shape, device=patch.device) // 2

    # Mark the center, which is left untouched by the structN2V mask
    mask[tuple(center)] = 0

    # Displacements from center, shape (ndim, n_displacements)
    displacements = torch.stack(torch.where(mask == 1)) - center.unsqueeze(1)

    # Combine all displacements (n_displacements, ndim) with all coords
    # (ndim, n_points), then flatten to (n_displacements * n_points, ndim)
    mix = displacements.T.unsqueeze(-1) + coords.T.unsqueeze(0)
    mix = mix.permute([1, 0, 2]).reshape([mask.ndim, -1]).T

    # Filter out indices falling outside the patch along the moving axis
    valid_indices = (mix[:, moving_axis] >= 0) & (
        mix[:, moving_axis] < patch.shape[moving_axis]
    )
    mix = mix[valid_indices]

    # Per-sample intensity range, reduced over *all* spatial dimensions so that
    # both (B, Y, X) and (B, Z, Y, X) patches yield exactly one value per sample
    flat = patch.reshape(patch.shape[0], -1)
    mins = flat.min(dim=1).values
    maxs = flat.max(dim=1).values
    for sample_idx in range(patch.shape[0]):
        batch_coords = mix[mix[:, 0] == sample_idx]
        min_ = mins[sample_idx].item()
        max_ = maxs[sample_idx].item()
        random_values = torch.empty(len(batch_coords), device=patch.device).uniform_(
            min_, max_, generator=rng
        )
        index = tuple(batch_coords[:, dim] for dim in range(patch.ndim))
        patch[index] = random_values

    return patch
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _get_stratified_coords_torch(
|
|
82
|
+
mask_pixel_perc: float,
|
|
83
|
+
shape: tuple[int, ...],
|
|
84
|
+
rng: torch.Generator,
|
|
85
|
+
) -> torch.Tensor:
|
|
86
|
+
"""
|
|
87
|
+
Generate coordinates of the pixels to mask.
|
|
88
|
+
|
|
89
|
+
Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
|
|
90
|
+
the distance between masked pixels is approximately the same. This is achieved by
|
|
91
|
+
defining a grid and sampling a pixel in each grid square. The grid is defined such
|
|
92
|
+
that the resulting density of masked pixels is the desired masked pixel percentage.
|
|
93
|
+
|
|
94
|
+
Parameters
|
|
95
|
+
----------
|
|
96
|
+
mask_pixel_perc : float
|
|
97
|
+
Expected value for percentage of masked pixels across the whole image.
|
|
98
|
+
shape : tuple[int, ...]
|
|
99
|
+
Shape of the input patch.
|
|
100
|
+
rng : torch.Generator or None
|
|
101
|
+
Random number generator.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
np.ndarray
|
|
106
|
+
Array of coordinates of the masked pixels.
|
|
107
|
+
"""
|
|
108
|
+
# Implementation logic:
|
|
109
|
+
# find a box size s.t sampling 1 pixel within the box will result in the desired
|
|
110
|
+
# pixel percentage. Make a grid of these boxes that cover the patch (the area of
|
|
111
|
+
# the grid will be greater than or equal to the area of the patch) and sample 1
|
|
112
|
+
# pixel in each box. The density of masked pixels is an intensive property therefore
|
|
113
|
+
# any subset of this area will have the desired expected masked pixel percentage.
|
|
114
|
+
# We can get our desired patch with our desired expected masked pixel percentage by
|
|
115
|
+
# simply filtering out masked pixels that lie outside of our patch bounds.
|
|
116
|
+
|
|
117
|
+
batch_size = shape[0]
|
|
118
|
+
spatial_shape = shape[1:]
|
|
119
|
+
|
|
120
|
+
n_dims = len(spatial_shape)
|
|
121
|
+
expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
|
|
122
|
+
|
|
123
|
+
# keep the grid size in floats for a more accurate expected masked pixel percentage
|
|
124
|
+
grid_size = expected_area_per_pixel ** (1 / n_dims)
|
|
125
|
+
grid_dims = torch.ceil(torch.tensor(spatial_shape) / grid_size).int()
|
|
126
|
+
|
|
127
|
+
# coords on a fixed grid (top left corner)
|
|
128
|
+
coords = torch.stack(
|
|
129
|
+
torch.meshgrid(
|
|
130
|
+
torch.arange(batch_size, dtype=torch.float),
|
|
131
|
+
*[torch.arange(0, grid_dims[i].item()) * grid_size for i in range(n_dims)],
|
|
132
|
+
indexing="ij",
|
|
133
|
+
),
|
|
134
|
+
-1,
|
|
135
|
+
).reshape(-1, n_dims + 1)
|
|
136
|
+
|
|
137
|
+
# add random offset to get a random coord in each grid box
|
|
138
|
+
# also keep the offset in floats
|
|
139
|
+
offset = (
|
|
140
|
+
torch.rand((len(coords), n_dims), device=rng.device, generator=rng) * grid_size
|
|
141
|
+
)
|
|
142
|
+
coords = coords.to(rng.device)
|
|
143
|
+
coords[:, 1:] += offset
|
|
144
|
+
coords = torch.floor(coords).int()
|
|
145
|
+
|
|
146
|
+
# filter pixels out of bounds
|
|
147
|
+
out_of_bounds = (
|
|
148
|
+
coords[:, 1:]
|
|
149
|
+
>= torch.tensor(spatial_shape, device=rng.device).reshape(1, n_dims)
|
|
150
|
+
).any(1)
|
|
151
|
+
coords = coords[~out_of_bounds]
|
|
152
|
+
return coords
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def uniform_manipulate_torch(
    patch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: StructMaskParameters | None = None,
    rng: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with a neighbor value.

    The first dimension of `patch` is treated as the batch dimension; all the
    remaining dimensions are spatial. Pixels to manipulate are selected in a
    stratified way (one pixel per grid box), which gives an approximately uniform
    probability of selection across the whole patch. Each selected pixel is
    replaced by the value of a pixel drawn uniformly from the subpatch centered on
    it; the drawn position is clamped to the patch bounds.

    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    patch : torch.Tensor
        Batch of image patches, shape (batch, y, x) or (batch, z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to remove the center pixel from the subpatch, by default True.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None
        Random number generator, expected on the same device as `patch`. If None,
        a new generator seeded non-deterministically is created.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The manipulated patch and a uint8 mask equal to 1 where the manipulated
        value differs from the original one. The mask is computed before the
        structN2V mask is applied.

    Raises
    ------
    ValueError
        If `subpatch_size` and `remove_center` leave no candidate offset to
        sample the replacement pixel from.
    """
    if rng is None:
        # A new torch.Generator always starts from the same default seed; seed it
        # explicitly so that repeated calls do not produce identical masks.
        rng = torch.Generator(device=patch.device)
        rng.seed()

    # create a copy of the patch
    transformed_patch = patch.clone()

    # get the coordinates of the pixels to be masked, (n_points, patch.ndim)
    subpatch_centers = _get_stratified_coords_torch(
        mask_pixel_percentage, patch.shape, rng
    )
    subpatch_centers = subpatch_centers.to(device=patch.device)

    # offsets defining the ROI around the pixel to be masked
    roi_span_full = torch.arange(
        -(subpatch_size // 2),
        subpatch_size // 2 + 1,
        dtype=torch.int32,
        device=patch.device,
    )

    # remove the center pixel from the ROI
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full
    if roi_span.numel() == 0:
        raise ValueError(
            f"No replacement candidates for subpatch_size={subpatch_size} "
            f"and remove_center={remove_center}."
        )

    # Sample uniformly an *index* into `roi_span` (not a value within its range,
    # which would wrap around via negative indexing and bias the distribution),
    # for every spatial coordinate of every center; the batch coordinate is skipped.
    random_indices = torch.randint(
        low=0,
        high=roi_span.numel(),
        size=(subpatch_centers.shape[0], subpatch_centers.shape[1] - 1),
        generator=rng,
        device=patch.device,
    )
    random_increment = roi_span[random_indices]

    # compute the replacement pixel coordinates, clamped to the patch bounds;
    # only the spatial dimensions are displaced, not the batch dimension
    upper_bounds = torch.tensor([v - 1 for v in patch.shape[1:]], device=patch.device)
    replacement_coords = subpatch_centers.clone()
    replacement_coords[:, 1:] = torch.clamp(
        replacement_coords[:, 1:] + random_increment,
        torch.zeros_like(upper_bounds),
        upper_bounds,
    )

    # replace the pixels in the patch
    # tuples and transpose are needed for proper indexing
    replacement_pixels = patch[tuple(replacement_coords.T)]
    transformed_patch[tuple(subpatch_centers.T)] = replacement_pixels

    # create a mask representing the masked pixels
    mask = (transformed_patch != patch).to(dtype=torch.uint8)

    # apply structN2V mask if needed
    if struct_params is not None:
        transformed_patch = _apply_struct_mask_torch(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return transformed_patch, mask
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def median_manipulate_torch(
    batch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: StructMaskParameters | None = None,
    rng: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version, manipulated pixels are selected randomly away from a grid with an
    approximate uniform probability to be selected across the whole patch. The
    median excludes the center pixel or, when `struct_params` is given, the whole
    structN2V line through the center. ROI positions falling outside the patch
    are clamped to the patch border.

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    Parameters
    ----------
    batch : torch.Tensor
        Batch of image patches, shape (batch, y, x) or (batch, z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed over, by default 11. Even
        values are rounded up to the next odd size so the ROI stays centered.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None, optional
        Random number generator, by default None. If None, a new generator seeded
        non-deterministically is created on `batch.device`.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The manipulated batch and a uint8 mask equal to 1 where the manipulated
        value differs from the original one. The mask is computed before the
        structN2V mask is applied.
    """
    if rng is None:
        # A new torch.Generator always starts from the same default seed; seed it
        # explicitly so that repeated calls do not produce identical masks.
        rng = torch.Generator(device=batch.device)
        rng.seed()

    # coordinates of the future ROI centers, (num_coordinates, batch.ndim); the
    # first column is the batch index
    subpatch_center_coordinates = _get_stratified_coords_torch(
        mask_pixel_percentage, batch.shape, rng
    ).to(device=batch.device)

    n_spatial_dims = batch.ndim - 1
    pad_value = subpatch_size // 2
    # effective ROI width; always odd so that the ROI is centered on the pixel
    window = 2 * pad_value + 1

    # all offsets of the ROIs, (window**n_spatial_dims, n_spatial_dims)
    offset_grids = torch.meshgrid(
        [
            torch.arange(-pad_value, pad_value + 1, device=batch.device)
            for _ in range(n_spatial_dims)
        ],
        indexing="ij",
    )
    offsets = torch.stack([grid.flatten() for grid in offset_grids], dim=1)
    n_roi_pixels = offsets.shape[0]

    # coordinates of the ROI centers, one tensor per axis
    coords_axes = [subpatch_center_coordinates[:, d] for d in range(batch.ndim)]

    # coordinates spanning each ROI, one (num_coordinates, n_roi_pixels) tensor
    # per axis: the batch index is repeated, spatial indices are offset and
    # clamped to the patch borders
    coords_expands = [coords_axes[0].unsqueeze(1).expand(-1, n_roi_pixels)]
    for d in range(1, batch.ndim):
        coords_expands.append(
            (coords_axes[d].unsqueeze(1) + offsets[:, d - 1]).clamp(
                0, batch.shape[d] - 1
            )
        )

    # create array of rois by indexing the batch, (num_coordinates, n_roi_pixels)
    rois = batch[tuple(coords_expands)]

    if struct_params is not None:
        # Exclude the structN2V line through the ROI center. It runs along the
        # last axis for axis 0 and along the second-to-last axis for axis 1
        # (same convention as `_apply_struct_mask_torch`), with every other
        # spatial coordinate at the center. Works for 2D and 3D ROIs.
        roi_grids = torch.meshgrid(
            [torch.arange(window, device=batch.device) for _ in range(n_spatial_dims)],
            indexing="ij",
        )
        moving_axis = n_spatial_dims - 1 - struct_params.axis
        halfspan = (struct_params.span - 1) // 2
        on_line = (roi_grids[moving_axis] >= pad_value - halfspan) & (
            roi_grids[moving_axis] <= pad_value + halfspan
        )
        for axis_idx, grid in enumerate(roi_grids):
            if axis_idx != moving_axis:
                on_line &= grid == pad_value
        rois_filtered = rois[:, ~on_line.flatten()]
    else:
        # Remove the center pixel value from the rois (middle of the flattened,
        # odd-sized ROI)
        center_idx = n_roi_pixels // 2
        rois_filtered = torch.cat(
            [rois[:, :center_idx], rois[:, center_idx + 1 :]], dim=1
        )

    # compute the medians (torch returns the lower of the two middle values for
    # an even number of elements)
    medians = rois_filtered.median(dim=1).values  # (num_coordinates,)

    # Update the output tensor with medians
    output_batch = batch.clone()
    output_batch[tuple(coords_axes)] = medians
    mask = torch.where(output_batch != batch, 1, 0).to(torch.uint8)

    if struct_params is not None:
        # Reuse the caller's generator for reproducibility when it lives on the
        # same device as the data (sampling on a device requires a matching
        # generator); otherwise let the helper create its own.
        struct_rng = rng if rng.device == batch.device else None
        output_batch = _apply_struct_mask_torch(
            output_batch, subpatch_center_coordinates, struct_params, struct_rng
        )

    return output_batch, mask
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Class representing the parameters of structN2V masks."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class StructMaskParameters:
    """Configuration of a structN2V mask.

    Attributes
    ----------
    axis : Literal[0, 1]
        Orientation of the mask: horizontal (0) or vertical (1).
    span : int
        Span of the mask.
    """

    # orientation of the mask: 0 -> horizontal, 1 -> vertical
    axis: Literal[0, 1]
    # extent of the mask along the chosen orientation
    span: int
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""A general parent class for transforms."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Transform:
    """General parent class for transforms.

    Concrete transforms override ``__call__``; the implementation provided
    here accepts arbitrary arguments and produces ``None``.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Apply the transform.

        Parameters
        ----------
        *args : Any
            Positional arguments.
        **kwargs : Any
            Keyword arguments.

        Returns
        -------
        Any
            Transformed data; always ``None`` for this base implementation.
        """
        return None
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Test-time augmentations."""
|
|
2
|
+
|
|
3
|
+
from torch import Tensor, flip, mean, rot90, stack
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ImageRestorationTTA:
    """
    Test-time augmentation (TTA) for image restoration tasks.

    `forward` produces eight views of the input: the original, its three
    90 degree rotations, the original flipped along Y and along X, and the
    90 degree rotation flipped along Y and along X. `backward` maps every
    view back to the original orientation and averages them.

    Tensors should be of shape SC(Z)YX; only the last two (YX) axes are
    rotated or flipped.

    This transformation is used in the LightningModule in order to perform
    test-time augmentation.
    """

    def forward(self, input_tensor: Tensor) -> list[Tensor]:
        """
        Apply test-time augmentation to the input tensor.

        Parameters
        ----------
        input_tensor : Tensor
            Input tensor, shape SC(Z)YX.

        Returns
        -------
        list of torch.Tensor
            Eight augmented tensors, in order: original, rotations by 1, 2
            and 3 quarter turns, original flipped along Y, original flipped
            along X, one-quarter-turn rotation flipped along Y, and
            one-quarter-turn rotation flipped along X.
        """
        yx_dims = (-2, -1)
        y_dim, x_dim = yx_dims

        # the unmodified input (same object) followed by the three rotations
        views = [input_tensor]
        views += [rot90(input_tensor, k, dims=yx_dims) for k in (1, 2, 3)]

        # mirrored copies of the original
        views += [flip(input_tensor, dims=(d,)) for d in (y_dim, x_dim)]

        # mirrored copies of the one-quarter-turn rotation
        quarter_turn = views[1]
        views += [flip(quarter_turn, dims=(d,)) for d in (y_dim, x_dim)]

        return views

    def backward(self, x: list[Tensor]) -> Tensor:
        """Undo the test-time augmentation.

        Parameters
        ----------
        x : list of torch.Tensor
            Augmented tensors of shape SC(Z)YX, ordered as produced by
            `forward`.

        Returns
        -------
        torch.Tensor
            Element-wise mean of the eight tensors after each has been
            mapped back to the original orientation.
        """
        yx_dims = (-2, -1)
        y_dim, x_dim = yx_dims

        restored = [x[0]]
        # a rotation by k quarter turns is undone by rotating -k quarter turns
        restored.extend(rot90(x[k], -k, dims=yx_dims) for k in (1, 2, 3))
        # a flip is its own inverse
        restored.append(flip(x[4], dims=(y_dim,)))
        restored.append(flip(x[5], dims=(x_dim,)))
        # forward applied rotation then flip, so undo the flip first
        restored.append(rot90(flip(x[6], dims=(y_dim,)), -1, dims=yx_dims))
        restored.append(rot90(flip(x[7], dims=(x_dim,)), -1, dims=yx_dims))

        return mean(stack(restored), dim=0)
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""XY flip transform."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from numpy.typing import NDArray
|
|
5
|
+
|
|
6
|
+
from careamics.transforms.transform import Transform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class XYFlip(Transform):
    """Flip image along X and Y axis, one at a time.

    This transform randomly flips one of the last two axes.

    This transform expects C(Z)YX dimensions.

    Attributes
    ----------
    axis_indices : list[int]
        Indices of the axes that can be flipped (-2 for Y, -1 for X).
    rng : np.random.Generator
        Random number generator.
    p : float
        Probability of applying the transform.

    Parameters
    ----------
    flip_x : bool, optional
        Whether to flip along the X axis, by default True.
    flip_y : bool, optional
        Whether to flip along the Y axis, by default True.
    p : float, optional
        Probability of applying the transform, by default 0.5.
    seed : int | None, optional
        Random seed, by default None.
    """

    def __init__(
        self,
        flip_x: bool = True,
        flip_y: bool = True,
        p: float = 0.5,
        seed: int | None = None,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        flip_x : bool, optional
            Whether to flip along the X axis, by default True.
        flip_y : bool, optional
            Whether to flip along the Y axis, by default True.
        p : float
            Probability of applying the transform, by default 0.5.
        seed : int | None, optional
            Random seed, by default None.

        Raises
        ------
        ValueError
            If `p` is outside [0, 1] or if both `flip_x` and `flip_y` are
            False.
        """
        if p < 0 or p > 1:
            raise ValueError("Probability must be in [0, 1].")

        if not flip_x and not flip_y:
            raise ValueError("At least one axis must be flippable.")

        # probability to apply the transform
        self.p = p

        # "flippable" axes, counted from the end so that both CYX and CZYX
        # inputs are handled
        self.axis_indices: list[int] = []

        if flip_y:
            self.axis_indices.append(-2)
        if flip_x:
            self.axis_indices.append(-1)

        # numpy random generator
        self.rng = np.random.default_rng(seed=seed)

    def __call__(
        self,
        patch: NDArray,
        target: NDArray | None = None,
        **additional_arrays: NDArray,
    ) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
        """Apply the transform to the source patch and the target (optional).

        Parameters
        ----------
        patch : np.ndarray
            Patch, 2D or 3D, shape C(Z)YX.
        target : np.ndarray | None, optional
            Target for the patch, by default None.
        **additional_arrays : NDArray
            Additional arrays that will be transformed identically to `patch` and
            `target`.

        Returns
        -------
        tuple[np.ndarray, np.ndarray | None, dict[str, np.ndarray]]
            Transformed patch, transformed target (None if no target was
            given), and the transformed additional arrays keyed by name. When
            the transform is not applied, the inputs are returned unchanged.
        """
        # `rng.random()` draws from the half-open interval [0, 1); using `>=`
        # guarantees that p=0 never applies the transform and p=1 always does.
        if self.rng.random() >= self.p:
            return patch, target, additional_arrays

        # choose an axis to flip (cast from numpy integer to match `_apply`)
        axis = int(self.rng.choice(self.axis_indices))

        patch_transformed = self._apply(patch, axis)
        target_transformed = self._apply(target, axis) if target is not None else None
        additional_transformed = {
            key: self._apply(array, axis) for key, array in additional_arrays.items()
        }

        return patch_transformed, target_transformed, additional_transformed

    def _apply(self, patch: NDArray, axis: int) -> NDArray:
        """Apply the transform to the image.

        Parameters
        ----------
        patch : np.ndarray
            Image patch, 2D or 3D, shape C(Z)YX.
        axis : int
            Axis to flip.

        Returns
        -------
        np.ndarray
            Flipped image patch, as a C-contiguous copy (`np.flip` alone
            returns a strided view).
        """
        return np.ascontiguousarray(np.flip(patch, axis=axis))
|