careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,377 @@
+"""MicroSplit patch synthesis."""
+
+# --- PROOF OF PRINCIPLE ---
+
+
+from collections.abc import Callable, Sequence
+from typing import Any, Literal, NamedTuple
+
+import numpy as np
+from numpy.typing import NDArray
+
+from .dataset import ImageRegionData
+from .image_stack import ImageStack
+from .patch_extractor import PatchExtractor
+from .patch_filter import PatchFilterProtocol
+from .patching_strategies import PatchingStrategy, PatchSpecs
+
+
+# TODO: better name
+# mirrors format of ImageRegionData
+class UncorrelatedRegionData(NamedTuple):
+    data: NDArray
+    source: Sequence[str | Literal["array"]]
+    data_shape: Sequence[Sequence[int]]
+    dtype: Sequence[str]  # dtype should be str for collate
+    axes: Sequence[str]
+    region_spec: Sequence[PatchSpecs]
+
+
+# --- for finding empty / signal channel patches in loop
+def is_empty(filter: PatchFilterProtocol) -> Callable[[NDArray[Any]], bool]:
+    def is_empty_check(patch: NDArray[Any]) -> bool:
+        return filter.filter_out(patch)
+
+    return is_empty_check
+
+
+def is_not_empty(filter: PatchFilterProtocol) -> Callable[[NDArray[Any]], bool]:
+    def is_not_empty_check(patch: NDArray[Any]) -> bool:
+        return not filter.filter_out(patch)
+
+    return is_not_empty_check
+
+
+# ---
+
+
+def create_default_input_target(
+    idx: int,
+    patch_extractor: PatchExtractor[ImageStack],
+    patching_strategy: PatchingStrategy,
+    alphas: list[float],
+    axes: str,  # annoyingly have to supply this to image region
+) -> tuple[ImageRegionData, ImageRegionData]:
+    """
+    Create a default MicroSplit patch with synthetically summed input.
+
+    Parameters
+    ----------
+    idx: int
+        The dataset index.
+    patch_extractor: PatchExtractor
+        Used to extract patches from the data.
+    patching_strategy: PatchingStrategy
+        Patch locations will be sampled using the patching strategy.
+    alphas: list[float]
+        Weights for each channel for creating the synthetic input with summation.
+    axes: str
+        The axes of the data. This is only used to populate metadata.
+
+    Returns
+    -------
+    input_region: ImageRegionData
+        The input patch and its metadata, the data has the dimension L(Z)YX.
+    target_region: ImageRegionData
+        The target patch and its metadata, the data has the dimensions C(Z)YX.
+    """
+    patch_spec = patching_strategy.get_patch_spec(idx)
+    patches = extract_microsplit_patch(patch_extractor, patch_spec)
+
+    ndims = len(patches.shape) - 1
+    alpha_broadcast = np.array(alphas)[:, *(np.newaxis for _ in range(ndims))]
+    # weight channels by alphas then sum on the channel axis
+    # input dims will be L(Z)YX
+    input_patch = (alpha_broadcast * patches).sum(axis=0)
+    target_patch = patches[:, 0, ...]  # first L patch
+
+    data_idx = patch_spec["data_idx"]
+    input_region = ImageRegionData(
+        input_patch,
+        source=str(patch_extractor.image_stacks[data_idx].source),
+        data_shape=patch_extractor.image_stacks[data_idx].data_shape,
+        dtype=str(patch_extractor.image_stacks[data_idx].data_dtype),
+        axes=axes,
+        region_spec=patch_spec,
+        additional_metadata={},
+    )
+    target_region = ImageRegionData(
+        target_patch,
+        source=str(patch_extractor.image_stacks[data_idx].source),
+        data_shape=patch_extractor.image_stacks[data_idx].data_shape,
+        dtype=str(patch_extractor.image_stacks[data_idx].data_dtype),
+        axes=axes,
+        region_spec=patch_spec,
+        additional_metadata={},
+    )
+    return input_region, target_region
+
+
+def create_uncorrelated_input_target(
+    patches: NDArray[Any],
+    patch_specs: list[PatchSpecs],
+    alphas: list[float],
+    patch_extractor: PatchExtractor[ImageStack],  # for metadata
+    axes: str,  # mirroring imageregion
+) -> tuple[UncorrelatedRegionData, UncorrelatedRegionData]:
+    """
+    Create MicroSplit target and synthetically summed input with metadata.
+
+    Parameters
+    ----------
+    patches: NDArray
+        Patches with dimensions LC(Z)YX, where L contains the lateral context at
+        multiple scales.
+    patch_specs: list[PatchSpecs]
+        The patch specs for each channel.
+    alphas: list[float]
+        Weights for each channel for creating the synthetic input with summation.
+    patch_extractor: PatchExtractor
+        The patch extractor the patches were extracted from. Used for additional
+        metadata.
+
+    Returns
+    -------
+    input_region: UncorrelatedRegionData
+        The input patch and its metadata, the data has the dimension L(Z)YX.
+    target_region: UncorrelatedRegionData
+        The target patch and its metadata, the data has the dimensions C(Z)YX.
+    """
+    ndims = len(patches.shape) - 1
+    alpha_broadcast = np.array(alphas)[:, *(np.newaxis for _ in range(ndims))]
+    # weight channels by alphas then sum on the channel axis
+    # input dims will be L(Z)YX
+    input_patch = (alpha_broadcast * patches).sum(axis=0)
+    target_patch = patches[:, 0, ...]  # first L patch
+
+    input_stacks = [
+        patch_extractor.image_stacks[patch_spec["data_idx"]]
+        for patch_spec in patch_specs
+    ]
+    source = [str(stack.source) for stack in input_stacks]
+    data_shape = [stack.data_shape for stack in input_stacks]
+    dtype = [str(stack.data_dtype) for stack in input_stacks]
+
+    input_region = UncorrelatedRegionData(
+        data=input_patch,
+        source=source,
+        data_shape=data_shape,
+        dtype=dtype,
+        region_spec=patch_specs,
+        axes=axes,
+    )
+    target_region = UncorrelatedRegionData(
+        data=target_patch,
+        source=source,
+        data_shape=data_shape,
+        dtype=dtype,
+        region_spec=patch_specs,
+        axes=axes,
+    )
+    return input_region, target_region
+
+
+def get_random_channel_patches(
+    idx: int,  # TODO: is this needed? it makes it work the same as the original dataset
+    patch_extractor: PatchExtractor[ImageStack],
+    patching_strategy: PatchingStrategy,
+    rng: np.random.Generator | None,
+) -> tuple[NDArray[Any], list[PatchSpecs]]:
+    """
+    Select patches from random patch locations for each channel.
+
+    Parameters
+    ----------
+    idx: int
+        The dataset index.
+    patch_extractor: PatchExtractor
+        Used to extract patches from the data.
+    patching_strategy: PatchingStrategy
+        Patch locations will be sampled using the patching strategy.
+    rng: numpy.random.Generator | None
+        Useful for seeding the process. If `None`, the default random number generator
+        will be used.
+
+    Returns
+    -------
+    NDArray[Any]
+        The resulting patches with dimensions LC(Z)YX, where L contains the lateral
+        context at multiple scales.
+    list[PatchSpecs]
+        A list of patch specifications, one for each channel.
+    """
+    if rng is None:
+        rng = np.random.default_rng()
+
+    n_channels = patch_extractor.n_channels
+
+    # in the original dataset, new random indices are chosen for each channel
+    # the other channels can come from anywhere in the entire dataset
+    indices = (idx, *rng.integers(patching_strategy.n_patches, size=(n_channels - 1)))
+
+    # get n different patch specs for n different channels
+    patch_specs = [patching_strategy.get_patch_spec(i) for i in indices]
+    patches = extract_microsplit_patch(patch_extractor, patch_specs)
+
+    return patches, patch_specs
+
+
+# TODO: better name
+def get_empty_channel_patches(
+    idx: int,
+    patch_extractor: PatchExtractor,
+    patching_strategy: PatchingStrategy,
+    signal_channels: dict[int, PatchFilterProtocol],
+    empty_channels: dict[int, PatchFilterProtocol],
+    patience: int,
+    rng: np.random.Generator | None,
+) -> tuple[NDArray[Any], list[PatchSpecs]]:
+    """
+    Select patches, specifying which channels should have signal and which should not.
+
+    Parameters
+    ----------
+    idx: int
+        The dataset index.
+    patch_extractor: PatchExtractor
+        Used to extract patches from the data.
+    patching_strategy: PatchingStrategy
+        Patch locations will be sampled using the patching strategy.
+    signal_channels: dict[int, PatchFilterProtocol]
+        A dictionary to specify the channels that should have signal and how they should
+        be filtered. The keys are the channel indices and the values are the patch
+        filters used to determine if the channel patch is empty or not.
+    empty_channels: dict[int, PatchFilterProtocol]
+        A dictionary to specify the channels that should not have signal. Similar to
+        the `signal_channels`.
+    patience: int
+        New patches are selected at random until a patch with signal or without is
+        found; the `patience` determines how many times to look before giving up.
+    rng: numpy.random.Generator | None
+        Useful for seeding the process. If `None`, the default random number generator
+        will be used.
+
+    Returns
+    -------
+    NDArray[Any]
+        The resulting patches with dimensions LC(Z)YX, where L contains the lateral
+        context at multiple scales.
+    list[PatchSpecs]
+        A list of patch specifications, one for each channel.
+    """
+    if rng is None:
+        rng = np.random.default_rng()
+
+    # if a channel is not selected to be empty or filled it will come from idx
+    filled = set(signal_channels.keys())
+    empty = set(empty_channels.keys())
+    if len(intersect := filled.intersection(empty)) != 0:
+        raise ValueError(
+            "Channels cannot be selected as both empty and filled, the following "
+            f"channels were selected as both {intersect}."
+        )
+
+    n_channels = patch_extractor.n_channels
+
+    # start with random initial patches
+    patches, patch_specs = get_random_channel_patches(
+        idx, patch_extractor, patching_strategy, rng
+    )
+
+    # for each channel sample patches until they are empty or not empty
+    for c in range(n_channels):
+
+        # criterion for the while loop
+        criterion: Callable[[NDArray[Any]], bool]
+        filter_: PatchFilterProtocol
+        if c in empty_channels:
+            filter_ = empty_channels[c]
+            criterion = is_not_empty(filter_)
+        elif c in signal_channels:
+            filter_ = signal_channels[c]
+            criterion = is_empty(filter_)
+        else:
+            break
+
+        patch = patches[c]
+        patch_spec = patch_specs[c]
+        patience_ = patience
+        # only check if primary input is empty
+        while criterion(patch[0]) and patience_ > 0:
+            # sample random indices from anywhere in the dataset
+            new_idx = rng.integers(patching_strategy.n_patches)
+            patch_spec = patching_strategy.get_patch_spec(new_idx.item())
+            patch = patch_extractor.extract_channel_patch(
+                data_idx=patch_spec["data_idx"],
+                sample_idx=patch_spec["sample_idx"],
+                channels=[c],
+                coords=patch_spec["coords"],
+                patch_size=patch_spec["patch_size"],
+            )[0]
+            # ^ removing channel dim
+            patience_ -= 1
+        if patience_ <= 0:
+            # TODO: log properly
+            print(f"Out of patience finding patch for channel {c}")
+
+        patches[c] = patch
+        patch_specs[c] = patch_spec
+
+    return patches, patch_specs
+
+
+def extract_microsplit_patch(
+    patch_extractor: PatchExtractor[ImageStack],
+    patch_specs: PatchSpecs | list[PatchSpecs],
+) -> NDArray[Any]:
+    """
+    Extract a MicroSplit patch with the dimensions LC(Z)YX.
+
+    This patch can be used to synthesize an input patch by summing the C dimension, and
+    it can be used to create a target patch by selecting the primary input from the
+    L dimension, where L is to store lateral context patches.
+
+    Parameters
+    ----------
+    patch_extractor: PatchExtractor
+        Used to extract patches from the data.
+    patch_specs: PatchSpecs | list[PatchSpecs]
+        A patch specification or a list of patch specifications, one for each channel.
+        Different patch specs can be used for each channel to create uncorrelated
+        channel patches.
+
+    Returns
+    -------
+    NDArray[Any]
+        The resulting patches with dimensions LC(Z)YX, where L contains the lateral
+        context at multiple scales.
+    """
+    if isinstance(patch_specs, list):
+        patches = np.concat(
+            [
+                patch_extractor.extract_channel_patch(
+                    data_idx=patch_spec["data_idx"],
+                    sample_idx=patch_spec["sample_idx"],
+                    channels=[c],
+                    coords=patch_spec["coords"],
+                    patch_size=patch_spec["patch_size"],
+                )
+                for c, patch_spec in enumerate(patch_specs)
+            ],
+            axis=0,
+        )
+    else:
+        patches = patch_extractor.extract_patch(
+            data_idx=patch_specs["data_idx"],
+            sample_idx=patch_specs["sample_idx"],
+            coords=patch_specs["coords"],
+            patch_size=patch_specs["patch_size"],
+        )
+    # Add L dimension if not present
+    n_spatial_dims = patch_extractor.n_spatial_dims
+    lateral_context_present = len(patches.shape) - n_spatial_dims == 2
+    if not lateral_context_present:
+        # insert a L dim
+        patches = patches[:, np.newaxis]
+
+    return patches
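The synthetic-input construction above reduces to an alpha-weighted sum over the channel axis of a patch shaped C, L, (Z), Y, X. A minimal illustrative sketch (not part of the package diff) of just that step, using a toy array:

import numpy as np

patches = np.random.rand(2, 3, 16, 16)                   # C=2 channels, L=3 scales, 16x16 patch
alphas = [0.5, 0.5]                                       # per-channel mixing weights
alpha_broadcast = np.array(alphas)[:, None, None, None]   # broadcast over L, Y, X
input_patch = (alpha_broadcast * patches).sum(axis=0)     # L(Z)YX, here (3, 16, 16)
target_patch = patches[:, 0, ...]                         # C(Z)YX, here (2, 16, 16)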
@@ -0,0 +1,50 @@
+from collections.abc import Sequence
+
+from numpy.typing import NDArray
+
+from ..image_stack import FileImageStack
+from .patch_construction import PatchConstructor, default_patch_constr
+from .patch_extractor import PatchExtractor
+
+
+class LimitFilesPatchExtractor(PatchExtractor[FileImageStack]):
+    """
+    A patch extractor that limits the number of files that have their data loaded.
+
+    This is useful for when not all of the data will fit into memory.
+    """
+
+    def __init__(
+        self,
+        image_stacks: Sequence[FileImageStack],
+        patch_constructor: PatchConstructor = default_patch_constr,
+    ):
+        """
+        Parameters
+        ----------
+        image_stacks: Sequence of `FileImageStack`
+        """
+        super().__init__(image_stacks, patch_constructor)
+        self.loaded_stacks: list[int] = []
+
+    def extract_channel_patch(
+        self,
+        data_idx: int,
+        sample_idx: int,
+        channels: Sequence[int] | None,
+        coords: Sequence[int],
+        patch_size: Sequence[int],
+    ) -> NDArray:
+        if data_idx not in self.loaded_stacks:
+            # TODO: make maximum images loaded configurable?
+            if len(self.loaded_stacks) >= 1:
+                # get the idx that was added longest ago
+                idx_to_close = self.loaded_stacks.pop(0)
+                self.image_stacks[idx_to_close].close()
+
+            self.image_stacks[data_idx].load()
+            self.loaded_stacks.append(data_idx)
+
+        return super().extract_channel_patch(
+            data_idx, sample_idx, channels, coords, patch_size
+        )
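LimitFilesPatchExtractor keeps at most one file's data loaded at a time, closing the stack that was loaded longest ago before loading a new one. A standalone sketch of that bookkeeping (not from the package), using hypothetical dummy stacks with load()/close() methods:

class DummyFileStack:
    def __init__(self, name: str):
        self.name = name

    def load(self) -> None:
        print(f"load {self.name}")

    def close(self) -> None:
        print(f"close {self.name}")


stacks = [DummyFileStack(f"img_{i}.tif") for i in range(3)]
loaded: list[int] = []


def ensure_loaded(data_idx: int) -> None:
    if data_idx not in loaded:
        if len(loaded) >= 1:                # keep a single file open at a time
            stacks[loaded.pop(0)].close()   # evict the stack loaded longest ago
        stacks[data_idx].load()
        loaded.append(data_idx)


for idx in (0, 0, 2, 1):
    ensure_loaded(idx)   # loads img_0, reuses it, then swaps to img_2 and img_1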
@@ -0,0 +1,151 @@
+from collections.abc import Sequence
+from typing import Any, Literal, Protocol
+
+import numpy as np
+from numpy.typing import NDArray
+from skimage.transform import resize
+
+from ..image_stack import ImageStack
+
+
+class PatchConstructor(Protocol):
+    """
+    A callable that modifies how patches are constructed in the PatchExtractor.
+
+    This protocol defines the signature of a callable that is passed as an argument to
+    the `PatchExtractor`. It can be used to modify how patches are constructed, for
+    example creating patches with multiple lateral context levels for MicroSplit.
+    """
+
+    def __call__(
+        self,
+        image_stack: ImageStack,
+        sample_idx: int,
+        channels: Sequence[int] | None,  # `channels = None` to select all channels
+        coords: Sequence[int],
+        patch_size: Sequence[int],
+    ) -> NDArray[Any]:
+        """
+        Parameters
+        ----------
+        image_stack: ImageStack
+            The image stack to construct a patch from.
+        sample_idx: int
+            Sample index. The first dimension of the image data will be indexed at this
+            value.
+        coords: Sequence of int
+            The coordinates that define the start of a patch.
+        patch_size: Sequence of int
+            The size of the patch in each spatial dimension.
+
+        Returns
+        -------
+        numpy.ndarray
+            The patch.
+        """
+        ...
+
+
+def default_patch_constr(
+    image_stack: ImageStack,
+    sample_idx: int,
+    channels: Sequence[int] | None,  # `channels = None` to select all channels
+    coords: Sequence[int],
+    patch_size: Sequence[int],
+) -> NDArray[Any]:
+    return image_stack.extract_channel_patch(
+        sample_idx=sample_idx,
+        channels=channels,
+        coords=coords,
+        patch_size=patch_size,
+    )
+
+
+# closure to create constructor funcs with particular multiscale_count and padding mode
+def lateral_context_patch_constr(
+    # TODO: will we stick with this as the parameter name
+    multiscale_count: int,
+    # TODO: add other modes?
+    padding_mode: Literal["reflect", "wrap"],
+) -> PatchConstructor:
+    """
+    Create a lateral context `PatchConstructor` for MicroSplit.
+
+    Parameters
+    ----------
+    multiscale_count : int
+        The number of multiscale inputs that will be created including the original
+        image size.
+    padding_mode : {"reflect", "wrap"}
+        How lateral context inputs will be padded at the edge of the image. See
+        [`numpy.pad`](https://numpy.org/devdocs/reference/generated/numpy.pad.html) for
+        more information.
+
+    Returns
+    -------
+    PatchConstructor
+        The patch constructor function. It will return patches with the dimensions
+        (C, L, (Z), Y, X) where L will be equal to `multiscale_count`, C is the number
+        of channels in the image, and (Z), Y, X are the patch size.
+    """
+
+    def constructor_func(
+        image_stack: ImageStack,
+        sample_idx: int,
+        channels: Sequence[int] | None,  # `channels = None` to select all channels
+        coords: Sequence[int],
+        patch_size: Sequence[int],
+    ) -> NDArray[Any]:
+        if channels is not None and len(channels) > 1:
+            raise NotImplementedError(
+                "Selecting multiple channels is currently not implemented for lateral "
+                "context patches. Select a single channel or pass `channels=None` to "
+                "select all channels."
+            )
+
+        shape = image_stack.data_shape
+        spatial_shape = shape[2:]
+        n_channels = shape[1] if channels is None else 1
+
+        # There will now be an additional lc dimension,
+        # this has to be handled correctly by the dataset
+        # TODO: maybe we want to limit this constructor to only images with 1 channel
+        # then we can put LCs in the channel dimension
+        # but not sure if this artificially limits potential use-cases
+        patch = np.zeros((n_channels, multiscale_count, *patch_size))
+        for scale in range(multiscale_count):
+            lc_patch_size = np.array(patch_size) * (2**scale)
+            lc_start = np.array(coords) + np.array(patch_size) // 2 - lc_patch_size // 2
+            lc_end = lc_start + np.array(lc_patch_size)
+
+            start_clipped = np.clip(
+                lc_start, np.zeros_like(spatial_shape), np.array(spatial_shape)
+            )
+            end_clipped = np.clip(
+                lc_end, np.zeros_like(spatial_shape), np.array(spatial_shape)
+            )
+            size_clipped = end_clipped - start_clipped
+
+            lc_patch = image_stack.extract_channel_patch(
+                sample_idx, channels, start_clipped, size_clipped
+            )
+            pad_before = start_clipped - lc_start
+            pad_after = lc_end - end_clipped
+            pad_width = np.concat(
+                [
+                    # zeros to not pad the channel axis
+                    np.zeros((1, 2), dtype=int),
+                    np.stack([pad_before, pad_after], axis=-1),
+                ]
+            )
+            lc_patch = np.pad(
+                lc_patch,
+                pad_width,
+                mode=padding_mode,
+            )
+            # TODO: test different downscaling? skimage suggests downscale_local_mean
+            lc_patch = resize(lc_patch, (n_channels, *patch_size))
+            patch[:, scale, ...] = lc_patch
+        return patch
+
+    return constructor_func
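For each scale, the constructor above reads a window of size patch_size * 2**scale centred on the requested patch, clips it to the image, pads the missing border, and resizes it back to patch_size. A short worked sketch (not from the package) of the window arithmetic only:

import numpy as np

coords, patch_size = np.array([100, 100]), np.array([32, 32])
for scale in range(3):                        # e.g. multiscale_count = 3
    lc_size = patch_size * (2 ** scale)       # 32, 64, 128 pixels per side
    lc_start = coords + patch_size // 2 - lc_size // 2
    lc_end = lc_start + lc_size
    print(scale, lc_start, lc_end)            # windows stay centred on the patch
# 0 [100 100] [132 132]
# 1 [ 84  84] [148 148]
# 2 [ 52  52] [180 180]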
@@ -0,0 +1,117 @@
+from collections.abc import Sequence
+from typing import Generic
+
+from numpy.typing import NDArray
+
+from ..image_stack import GenericImageStack
+from .patch_construction import PatchConstructor, default_patch_constr
+
+
+class PatchExtractor(Generic[GenericImageStack]):
+    """
+    A class for extracting patches from multiple image stacks.
+    """
+
+    def __init__(
+        self,
+        image_stacks: Sequence[GenericImageStack],
+        patch_constructor: PatchConstructor = default_patch_constr,
+    ):
+        self.patch_constructor = patch_constructor
+        self.image_stacks: list[GenericImageStack] = list(image_stacks)
+
+        # check all image stacks have the same number of dimensions
+        # check all image stacks have the same number of channels
+        self.n_spatial_dims = len(self.image_stacks[0].data_shape) - 2  # SC(Z)YX
+        self.n_channels = self.image_stacks[0].data_shape[1]
+        for i, image_stack in enumerate(image_stacks):
+            if (ndims := len(image_stack.data_shape) - 2) != self.n_spatial_dims:
+                raise ValueError(
+                    "All `ImageStack` objects in a `PatchExtractor` must have the same "
+                    "number of spatial dimensions. The first image stack is "
+                    f"{self.n_spatial_dims}D but found a {ndims}D image stack at index "
+                    f"{i}."
+                )
+            if (n_channels := image_stack.data_shape[1]) != self.n_channels:
+                raise ValueError(
+                    "All `ImageStack` objects in a `PatchExtractor` must have the same "
+                    f"number of channels. The first image stack has {self.n_channels} "
+                    f"but found an image stack with {n_channels} at index {i}."
+                )
+
+    def extract_patch(
+        self,
+        data_idx: int,
+        sample_idx: int,
+        coords: Sequence[int],
+        patch_size: Sequence[int],
+    ) -> NDArray:
+        """Extract a patch from the specified image stack across all channels.
+
+        Equivalent to calling `extract_channel_patch` with `channels=None`.
+
+        Parameters
+        ----------
+        data_idx : int
+            Index of the image stack to extract the patch from.
+        sample_idx : int
+            Sample index. The first dimension of the image data will be indexed at this
+            value.
+        coords : Sequence of int
+            The coordinates that define the start of a patch.
+        patch_size : Sequence of int
+            The size of the patch in each spatial dimension.
+
+        Returns
+        -------
+        numpy.ndarray
+            The extracted patch.
+        """
+        return self.extract_channel_patch(
+            data_idx=data_idx,
+            sample_idx=sample_idx,
+            channels=None,
+            coords=coords,
+            patch_size=patch_size,
+        )
+
+    def extract_channel_patch(
+        self,
+        data_idx: int,
+        sample_idx: int,
+        channels: Sequence[int] | None,
+        coords: Sequence[int],
+        patch_size: Sequence[int],
+    ) -> NDArray:
+        """Extract a patch from the specified image stack.
+
+        Parameters
+        ----------
+        data_idx : int
+            Index of the image stack to extract the patch from.
+        sample_idx : int
+            Sample index. The first dimension of the image data will be indexed at this
+            value.
+        channels : Sequence of int | None
+            Channels to extract. If `None`, all channels are extracted.
+        coords : Sequence of int
+            The coordinates that define the start of a patch.
+        patch_size : Sequence of int
+            The size of the patch in each spatial dimension.
+
+        Returns
+        -------
+        numpy.ndarray
+            The extracted patch.
+        """
+        return self.patch_constructor(
+            self.image_stacks[data_idx],
+            sample_idx=sample_idx,
+            channels=channels,
+            coords=coords,
+            patch_size=patch_size,
+        )
+
+    @property
+    def shapes(self) -> list[Sequence[int]]:
+        return [stack.data_shape for stack in self.image_stacks]
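A hypothetical usage sketch (not part of the package) of the PatchExtractor above with a minimal array-backed stack. It assumes PatchExtractor is importable from careamics.dataset_ng.patch_extractor, and the toy stack only implements what the extractor actually touches (data_shape and extract_channel_patch):

import numpy as np

from careamics.dataset_ng.patch_extractor import PatchExtractor  # assumed import path


class ArrayStack:
    """Toy stand-in for an ImageStack, backed by an SC(Z)YX numpy array."""

    def __init__(self, data: np.ndarray):
        self.data = data
        self.data_shape = data.shape

    def extract_channel_patch(self, sample_idx, channels, coords, patch_size):
        spatial = tuple(slice(c, c + s) for c, s in zip(coords, patch_size))
        patch = self.data[(sample_idx, slice(None), *spatial)]
        return patch if channels is None else patch[list(channels)]


stacks = [ArrayStack(np.random.rand(1, 2, 64, 64))]     # S=1, C=2, 64x64 image
extractor = PatchExtractor(stacks)                       # default patch constructor
patch = extractor.extract_patch(0, 0, coords=(8, 8), patch_size=(16, 16))
print(patch.shape)                                       # (2, 16, 16)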