careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""A module for random patching strategies."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from .patching_strategy_protocol import PatchSpecs
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RandomPatchingStrategy:
    """
    A patching strategy for sampling random patches, it implements the
    `PatchingStrategy` `Protocol`.

    The output of `get_patch_spec` will be random, i.e. if the same index is given
    twice the two outputs can be different.

    However the strategy still ensures that there will be a known number of patches for
    each sample in each image stack. This is achieved through defining a set of bins
    that map to each sample in each image stack. Whichever bin an `index` passed to
    `get_patch_spec` falls into, determines the `"data_idx"` and `"sample_idx"` in
    the returned `PatchSpecs`, but the `"coords"` will be random.

    The number of patches in each sample is based on the number of patches that would
    fit if they were sampled sequentially, non-overlapping, and covering the entire
    array.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        seed: int | None = None,
    ):
        """
        A patching strategy for sampling random patches.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        seed : int, optional
            An optional seed to ensure the reproducibility of the random patches.
        """
        self.rng = np.random.default_rng(seed=seed)
        self.patch_size = patch_size
        self.data_shapes = data_shapes

        # these bins will determine which image stack and sample a patch comes from
        # the image_stack_cumulative_patches map a patch index to each image stack
        # the sample_cumulative_patches map a patch index to each sample
        # the image_stack_cumulative_samples map a sample index to each image stack
        (
            self.image_stack_cumulative_patches,
            self.sample_cumulative_patches,
            self.image_stack_cumulative_samples,
        ) = self._calc_bins(self.data_shapes, self.patch_size)

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        # no image stacks means no bins, and therefore no patches
        if not self.image_stack_cumulative_patches:
            return 0
        # last bin boundary will be total patches
        return self.image_stack_cumulative_patches[-1]

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.

        Raises
        ------
        IndexError
            If `index` is greater than or equal to `n_patches`.
        """
        # TODO: break into smaller testable functions?
        if index >= self.n_patches:
            raise IndexError(
                f"Index {index} out of bounds for RandomPatchingStrategy with number "
                f"of patches {self.n_patches}"
            )
        # digitize returns the bin that `index` belongs to
        data_index = np.digitize(index, bins=self.image_stack_cumulative_patches).item()
        # maps to a particular sample within the whole series of image stacks
        # (not just a single image stack)
        total_samples_index = np.digitize(
            index, bins=self.sample_cumulative_patches
        ).item()

        data_shape = self.data_shapes[data_index]
        spatial_shape = data_shape[2:]

        # calculate sample index relative to image stack:
        # subtract the total number of samples in the previous image stacks
        if data_index == 0:
            n_previous_samples = 0
        else:
            n_previous_samples = self.image_stack_cumulative_samples[data_index - 1]
        # the bins hold python ints, so this stays a python int for every image stack
        sample_index = total_samples_index - n_previous_samples
        coords = _generate_random_coords(spatial_shape, self.patch_size, self.rng)
        return {
            "data_idx": data_index,
            "sample_idx": sample_index,
            "coords": coords,
            "patch_size": self.patch_size,
        }

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the patch indices that will return patches from a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices, that when used to index the `CAREamicsDataset`
            will return a patch that comes from the `image_stack` corresponding to the
            given `data_idx`.
        """
        # return all the values in the corresponding bin
        if data_idx == 0:
            start = 0
        else:
            start = self.image_stack_cumulative_patches[data_idx - 1]

        return np.arange(start, self.image_stack_cumulative_patches[data_idx]).tolist()

    @staticmethod
    def _calc_bins(
        data_shapes: Sequence[Sequence[int]], patch_size: Sequence[int]
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Calculate bins used to map an index to an image_stack and a sample.

        The number of patches in each sample is based on the number of patches that
        would fit if they were sampled sequentially.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.

        Returns
        -------
        image_stack_cumulative_patches: tuple of int
            The bins that map a patch index to an image stack. E.g. if a patch index
            falls below the first bin boundary it belongs to the first image stack, if
            a patch index falls between the first bin boundary and the second bin
            boundary it belongs to the second image stack, and so on.
        sample_cumulative_patches: tuple of int
            The bins that map a patch index to a sample. E.g. if a patch index
            falls below the first bin boundary it belongs to the first sample, if
            a patch index falls between the first bin boundary and the second bin
            boundary it belongs to the second sample, and so on.
        image_stack_cumulative_samples: tuple of int
            The bins that map a sample index to an image stack. E.g. if a sample index
            falls below the first bin boundary it belongs to the first image stack, if
            a sample index falls between the first bin boundary and the second bin
            boundary it belongs to the second image stack, and so on.
        """
        patches_per_image_stack: list[int] = []
        patches_per_sample: list[int] = []
        samples_per_image_stack: list[int] = []
        for data_shape in data_shapes:
            spatial_shape = data_shape[2:]
            n_single_sample_patches = _calc_n_patches(spatial_shape, patch_size)
            # multiply by number of samples in image_stack
            patches_per_image_stack.append(n_single_sample_patches * data_shape[0])
            # list of length `sample` filled with `n_single_sample_patches`
            patches_per_sample.extend([n_single_sample_patches] * data_shape[0])
            # number of samples in each image stack
            samples_per_image_stack.append(data_shape[0])

        # cumulative sum creates the bins; `tolist` converts the numpy scalars to
        # python ints so that the returned tuples match their annotation
        image_stack_cumulative_patches = np.cumsum(patches_per_image_stack).tolist()
        sample_cumulative_patches = np.cumsum(patches_per_sample).tolist()
        image_stack_cumulative_samples = np.cumsum(samples_per_image_stack).tolist()
        return (
            tuple(image_stack_cumulative_patches),
            tuple(sample_cumulative_patches),
            tuple(image_stack_cumulative_samples),
        )
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class FixedRandomPatchingStrategy:
    """
    A random patching strategy whose patches are fixed at construction time, it
    implements the `PatchingStrategy` `Protocol`.

    Every patch is drawn randomly once, during initialisation, and stored. The
    output of `get_patch_spec` is therefore deterministic: asking for the same
    index twice returns the same patch specification.

    The number of patches in each sample is based on the number of patches that would
    fit if they were sampled sequentially, non-overlapping, and covering the entire
    array.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        seed: int | None = None,
    ):
        """Pre-generate a fixed set of random patches.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        seed : int, optional
            An optional seed to ensure the reproducibility of the random patches.
        """
        self.rng = np.random.default_rng(seed=seed)
        self.patch_size = patch_size
        self.data_shapes = data_shapes

        # all patch specs are drawn here, once, which is what makes them fixed
        self.fixed_patch_specs: list[PatchSpecs] = []
        for stack_number, stack_shape in enumerate(self.data_shapes):
            n_samples = stack_shape[0]
            stack_spatial_shape = stack_shape[2:]
            per_sample = _calc_n_patches(stack_spatial_shape, self.patch_size)
            # walk the (sample, patch) pairs in sample-major order with one counter
            for flat_idx in range(n_samples * per_sample):
                spec: PatchSpecs = {
                    "data_idx": stack_number,
                    "sample_idx": flat_idx // per_sample,
                    "coords": _generate_random_coords(
                        stack_spatial_shape, self.patch_size, self.rng
                    ),
                    "patch_size": self.patch_size,
                }
                self.fixed_patch_specs.append(spec)

    @property
    def n_patches(self):
        """
        The number of patches that this patching strategy will return.

        It is the number of pre-generated patch specs, and it also determines the
        maximum index that can be given to `get_patch_spec`.
        """
        return len(self.fixed_patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the pre-generated patch specs stored at a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.

        Raises
        ------
        IndexError
            If `index` is greater than or equal to `n_patches`.
        """
        if index < self.n_patches:
            # the patches already exist, so this is a plain lookup
            return self.fixed_patch_specs[index]
        raise IndexError(
            f"Index {index} out of bounds for FixedRandomPatchingStrategy with "
            f"number of patches, {self.n_patches}"
        )

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the patch indices that will return patches from a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices, that when used to index the `CAREamicsDataset`
            will return a patch that comes from the `image_stack` corresponding to the
            given `data_idx`.
        """
        matching_indices = []
        for position, spec in enumerate(self.fixed_patch_specs):
            if spec["data_idx"] == data_idx:
                matching_indices.append(position)
        return matching_indices
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _generate_random_coords(
    spatial_shape: Sequence[int], patch_size: Sequence[int], rng: np.random.Generator
) -> tuple[int, ...]:
    """Draw a random patch origin for a given `spatial_shape` and `patch_size`.

    The coords are the top-left (and first z-slice for 3D data) of a patch. Along
    each axis the coordinate is drawn from the inclusive range
    `[0, spatial_shape - patch_size]`; if the patch is larger than the axis the
    upper bound is clipped to zero, so the coordinate is 0.

    Parameters
    ----------
    spatial_shape : sequence of int
        The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
        data respectively.
    patch_size : sequence of int
        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    rng : numpy.random.Generator
        A numpy generator to ensure the reproducibility of the random patches.

    Returns
    -------
    coords: tuple of int
        The top-left (and first z-slice for 3D data) coords of a patch. The tuple has
        one entry per spatial dimension.

    Raises
    ------
    ValueError
        If the number of spatial dimensions does not match the number of patch
        dimensions.
    """
    n_patch_dims = len(patch_size)
    if n_patch_dims != len(spatial_shape):
        raise ValueError(
            f"Number of patch dimension {len(patch_size)}, do not match the number of "
            f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
            f"and `spatial_shape={spatial_shape}`."
        )

    lowest_origin = np.zeros(n_patch_dims, dtype=int)
    # largest origin that keeps the patch inside the array, never below zero
    highest_origin = np.clip(np.array(spatial_shape) - np.array(patch_size), 0, None)
    # `endpoint=True` makes the upper bound inclusive
    sampled_origin = rng.integers(
        lowest_origin, highest_origin, endpoint=True, dtype=int
    )
    return tuple(sampled_origin.tolist())
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) -> int:
|
|
363
|
+
"""
|
|
364
|
+
Calculates the number of patches for a given `spatial_shape` and `patch_size`.
|
|
365
|
+
|
|
366
|
+
This is based on the number of patches that would fit if they were sampled
|
|
367
|
+
sequentially.
|
|
368
|
+
|
|
369
|
+
Parameters
|
|
370
|
+
----------
|
|
371
|
+
spatial_shape : sequence of int
|
|
372
|
+
The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
|
|
373
|
+
data respectively.
|
|
374
|
+
patch_size : sequence of int
|
|
375
|
+
The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
|
|
376
|
+
data respectively.
|
|
377
|
+
|
|
378
|
+
Returns
|
|
379
|
+
-------
|
|
380
|
+
int
|
|
381
|
+
The number of patches.
|
|
382
|
+
"""
|
|
383
|
+
if len(patch_size) != len(spatial_shape):
|
|
384
|
+
raise ValueError(
|
|
385
|
+
f"Number of patch dimension {len(patch_size)}, do not match the number of "
|
|
386
|
+
f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
|
|
387
|
+
f"and `spatial_shape={spatial_shape}`."
|
|
388
|
+
)
|
|
389
|
+
patches_per_dim = [
|
|
390
|
+
np.ceil(s / p) for s, p in zip(spatial_shape, patch_size, strict=False)
|
|
391
|
+
]
|
|
392
|
+
total_patches = int(np.prod(patches_per_dim))
|
|
393
|
+
return total_patches
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from typing_extensions import ParamSpec
|
|
6
|
+
|
|
7
|
+
from .patching_strategy_protocol import PatchSpecs
|
|
8
|
+
|
|
9
|
+
P = ParamSpec("P")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO: this is an unfinished prototype based on current tiling implementation
|
|
13
|
+
# not guaranteed to work!
|
|
14
|
+
class SequentialPatchingStrategy:
    """
    Patching strategy that extracts patches on a regular grid.

    Patches are laid out sequentially along every spatial axis with a stride of
    `patch_size - overlap`. If the grid does not reach the end of an axis, an
    additional patch aligned with the end of the axis is added. The same grid is
    used for every sample of an image stack.

    Note that this class is an unfinished prototype.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        overlaps: Sequence[int] | None = None,
    ):
        """
        Compute the patch specifications for all the given data shapes.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. The first element of each shape is
            used as the number of samples and the last `len(patch_size)` elements
            as the spatial shape.
        patch_size : sequence of int
            The size of the patch in each spatial dimension.
        overlaps : sequence of int, optional
            How much a patch overlaps with the next patch in each spatial
            dimension. If `None`, the patches do not overlap.

        Raises
        ------
        ValueError
            If an overlap is greater than or equal to the corresponding patch size.
        """
        self.data_shapes = data_shapes
        self.patch_size = patch_size
        if overlaps is None:
            overlaps = [0] * len(patch_size)
        self.overlaps = np.asarray(overlaps)

        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        return len(self.patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch.
        """
        return self.patch_specs[index]

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the indices of the patches that belong to a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices, that when used to index the
            `CAREamicsDataset` will return a patch that comes from the
            `image_stack` corresponding to the given `data_idx`.
        """
        return [
            i
            for i, patch_spec in enumerate(self.patch_specs)
            if patch_spec["data_idx"] == data_idx
        ]

    def _compute_coords_1d(
        self, patch_size: int, spatial_shape: int, overlap: int
    ) -> list[tuple[int, int]]:
        """
        Compute the `(start, end)` coordinates of the patches along a single axis.

        Parameters
        ----------
        patch_size : int
            The size of the patch along the axis.
        spatial_shape : int
            The length of the axis.
        overlap : int
            The overlap between two consecutive patches.

        Returns
        -------
        list of (int, int)
            The start and end of each patch, where `end - start == patch_size`.
            If the axis is shorter than the patch, a single patch starting at 0
            is returned.

        Raises
        ------
        ValueError
            If `overlap` is greater than or equal to `patch_size`.
        """
        # `overlap` is read from a numpy array, cast so that the coordinates are
        # built-in ints
        patch_size = int(patch_size)
        spatial_shape = int(spatial_shape)
        overlap = int(overlap)

        step = patch_size - overlap
        if step <= 0:
            # a non-positive step would never advance along the axis
            raise ValueError(
                f"Overlap ({overlap}) must be smaller than the patch size "
                f"({patch_size})."
            )

        last_start = spatial_shape - patch_size
        if last_start < 0:
            # the patch does not fit: anchor a single patch at the origin, as the
            # other patching strategies do
            return [(0, patch_size)]

        starts = list(range(0, last_start + 1, step))
        if starts[-1] != last_start:
            # the grid stops short of the end of the axis, add an end-aligned patch
            starts.append(last_start)

        return [(start, start + patch_size) for start in starts]

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """
        Build the patch specifications for every sample of every data shape.

        Returns
        -------
        list of PatchSpecs
            The patch specifications, ordered by data index, then sample index,
            then grid position.
        """
        patch_specs: list[PatchSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):

            data_spatial_shape = data_shape[-len(self.patch_size) :]
            coords_list = [
                self._compute_coords_1d(
                    self.patch_size[i], data_spatial_shape[i], self.overlaps[i]
                )
                for i in range(len(self.patch_size))
            ]
            # the grid is the same for every sample, so only compute it once
            patch_origins = [
                tuple(start for start, _ in crop_coord)
                for crop_coord in itertools.product(*coords_list)
            ]
            for sample_idx in range(data_shape[0]):
                for coords in patch_origins:
                    patch_specs.append(
                        PatchSpecs(
                            data_idx=data_idx,
                            sample_idx=sample_idx,
                            coords=coords,
                            patch_size=self.patch_size,
                        )
                    )
        return patch_specs
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Module for the `TilingStrategy` class."""
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from math import prod
|
|
6
|
+
|
|
7
|
+
from .patching_strategy_protocol import TileSpecs
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TilingStrategy:
|
|
11
|
+
"""
|
|
12
|
+
The tiling strategy should be used for prediction. The `get_patch_specs`
|
|
13
|
+
method returns `TileSpec` dictionaries that contains information on how to
|
|
14
|
+
stitch the tiles back together to create the full image.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
def __init__(
|
|
18
|
+
self,
|
|
19
|
+
data_shapes: Sequence[Sequence[int]],
|
|
20
|
+
patch_size: Sequence[int],
|
|
21
|
+
overlaps: Sequence[int],
|
|
22
|
+
):
|
|
23
|
+
"""
|
|
24
|
+
The tiling strategy should be used for prediction. The `get_patch_specs`
|
|
25
|
+
method returns `TileSpec` dictionaries that contains information on how to
|
|
26
|
+
stitch the tiles back together to create the full image.
|
|
27
|
+
|
|
28
|
+
Parameters
|
|
29
|
+
----------
|
|
30
|
+
data_shapes : sequence of (sequence of int)
|
|
31
|
+
The shapes of the underlying data. Each element is the dimension of the
|
|
32
|
+
axes SC(Z)YX.
|
|
33
|
+
patch_size : sequence of int
|
|
34
|
+
The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
|
|
35
|
+
data respectively.
|
|
36
|
+
overlaps : sequence of int
|
|
37
|
+
How much a tile will overlap with adjacent tiles in each spatial dimension.
|
|
38
|
+
"""
|
|
39
|
+
self.data_shapes = data_shapes
|
|
40
|
+
self.patch_size = patch_size
|
|
41
|
+
self.overlaps = overlaps
|
|
42
|
+
# patch_size and overlap should have same length validated in pydantic configs
|
|
43
|
+
self.tile_specs: list[TileSpecs] = self._generate_specs()
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def n_patches(self) -> int:
|
|
47
|
+
"""
|
|
48
|
+
The number of patches that this patching strategy will return.
|
|
49
|
+
|
|
50
|
+
It also determines the maximum index that can be given to `get_patch_spec`.
|
|
51
|
+
"""
|
|
52
|
+
return len(self.tile_specs)
|
|
53
|
+
|
|
54
|
+
def get_patch_spec(self, index: int) -> TileSpecs:
|
|
55
|
+
"""Return the tile specs for a given index.
|
|
56
|
+
|
|
57
|
+
Parameters
|
|
58
|
+
----------
|
|
59
|
+
index : int
|
|
60
|
+
A patch index.
|
|
61
|
+
|
|
62
|
+
Returns
|
|
63
|
+
-------
|
|
64
|
+
TileSpecs
|
|
65
|
+
A dictionary that specifies a single patch in a series of `ImageStacks`.
|
|
66
|
+
"""
|
|
67
|
+
return self.tile_specs[index]
|
|
68
|
+
|
|
69
|
+
# Note: this is used by the FileIterSampler
|
|
70
|
+
def get_patch_indices(self, data_idx: int) -> Sequence[int]:
|
|
71
|
+
"""
|
|
72
|
+
Get the patch indices will return patches for a specific `image_stack`.
|
|
73
|
+
|
|
74
|
+
The `image_stack` corresponds to the given `data_idx`.
|
|
75
|
+
|
|
76
|
+
Parameters
|
|
77
|
+
----------
|
|
78
|
+
data_idx : int
|
|
79
|
+
An index that corresponds to a given `image_stack`.
|
|
80
|
+
|
|
81
|
+
Returns
|
|
82
|
+
-------
|
|
83
|
+
sequence of int
|
|
84
|
+
A sequence of patch indices, that when used to index the `CAREamicsDataset
|
|
85
|
+
will return a patch that comes from the `image_stack` corresponding to the
|
|
86
|
+
given `data_idx`.
|
|
87
|
+
"""
|
|
88
|
+
return [
|
|
89
|
+
i
|
|
90
|
+
for i, patch_spec in enumerate(self.tile_specs)
|
|
91
|
+
if patch_spec["data_idx"] == data_idx
|
|
92
|
+
]
|
|
93
|
+
|
|
94
|
+
def _generate_specs(self) -> list[TileSpecs]:
|
|
95
|
+
tile_specs: list[TileSpecs] = []
|
|
96
|
+
for data_idx, data_shape in enumerate(self.data_shapes):
|
|
97
|
+
spatial_shape = data_shape[2:]
|
|
98
|
+
|
|
99
|
+
# spec info for each axis
|
|
100
|
+
axis_specs: list[tuple[list[int], list[int], list[int], list[int]]] = [
|
|
101
|
+
self._compute_1d_coords(
|
|
102
|
+
axis_size, self.patch_size[axis_idx], self.overlaps[axis_idx]
|
|
103
|
+
)
|
|
104
|
+
for axis_idx, axis_size in enumerate(spatial_shape)
|
|
105
|
+
]
|
|
106
|
+
|
|
107
|
+
# combine by using zip
|
|
108
|
+
all_coords, all_stitch_coords, all_crop_coords, all_crop_size = zip(
|
|
109
|
+
*axis_specs, strict=False
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
# number of tiles for this data_idx
|
|
113
|
+
n_tiles = prod(len(dim) for dim in all_coords) * data_shape[0]
|
|
114
|
+
|
|
115
|
+
# patches will be the same for each sample in a stack
|
|
116
|
+
for sample_idx in range(data_shape[0]):
|
|
117
|
+
# iterate through all combinations using itertools.product
|
|
118
|
+
for coords, stitch_coords, crop_coords, crop_size in zip(
|
|
119
|
+
itertools.product(*all_coords),
|
|
120
|
+
itertools.product(*all_stitch_coords),
|
|
121
|
+
itertools.product(*all_crop_coords),
|
|
122
|
+
itertools.product(*all_crop_size),
|
|
123
|
+
strict=False,
|
|
124
|
+
):
|
|
125
|
+
tile_specs.append(
|
|
126
|
+
{
|
|
127
|
+
# PatchSpecs
|
|
128
|
+
"data_idx": data_idx,
|
|
129
|
+
"sample_idx": sample_idx,
|
|
130
|
+
"coords": coords,
|
|
131
|
+
"patch_size": self.patch_size,
|
|
132
|
+
# TileSpecs additional fields
|
|
133
|
+
"crop_coords": crop_coords,
|
|
134
|
+
"crop_size": crop_size,
|
|
135
|
+
"stitch_coords": stitch_coords,
|
|
136
|
+
"total_tiles": n_tiles,
|
|
137
|
+
}
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
return tile_specs
|
|
141
|
+
|
|
142
|
+
@staticmethod
|
|
143
|
+
def _compute_1d_coords(
|
|
144
|
+
axis_size: int, patch_size: int, overlap: int
|
|
145
|
+
) -> tuple[list[int], list[int], list[int], list[int]]:
|
|
146
|
+
"""
|
|
147
|
+
Computes the TileSpec information for a single axis.
|
|
148
|
+
|
|
149
|
+
Parameters
|
|
150
|
+
----------
|
|
151
|
+
axis_size : int
|
|
152
|
+
The size of the axis.
|
|
153
|
+
patch_size : int
|
|
154
|
+
The tile size.
|
|
155
|
+
overlap : int
|
|
156
|
+
The tile overlap.
|
|
157
|
+
|
|
158
|
+
Returns
|
|
159
|
+
-------
|
|
160
|
+
coords: list of int
|
|
161
|
+
The top-left (and first z-slice for 3D data) of a tile, in coords relative
|
|
162
|
+
to the image.
|
|
163
|
+
stitch_coords: list of int
|
|
164
|
+
Where the tile will be stitched back into an image, taking into account
|
|
165
|
+
that the tile will be cropped, in coords relative to the image.
|
|
166
|
+
crop_coords: list of int
|
|
167
|
+
The top-left side of where the tile will be cropped, in coordinates relative
|
|
168
|
+
to the tile.
|
|
169
|
+
crop_size: list of int
|
|
170
|
+
The size of the cropped tile.
|
|
171
|
+
"""
|
|
172
|
+
coords: list[int] = []
|
|
173
|
+
stitch_coords: list[int] = []
|
|
174
|
+
crop_coords: list[int] = []
|
|
175
|
+
crop_size: list[int] = []
|
|
176
|
+
|
|
177
|
+
step = patch_size - overlap
|
|
178
|
+
for i in range(0, max(1, axis_size - overlap), step):
|
|
179
|
+
if i == 0:
|
|
180
|
+
coords.append(i)
|
|
181
|
+
crop_coords.append(0)
|
|
182
|
+
stitch_coords.append(0)
|
|
183
|
+
if axis_size <= patch_size:
|
|
184
|
+
crop_size.append(axis_size)
|
|
185
|
+
else:
|
|
186
|
+
crop_size.append(patch_size - overlap // 2)
|
|
187
|
+
elif (0 < i) and (i + patch_size < axis_size):
|
|
188
|
+
coords.append(i)
|
|
189
|
+
crop_coords.append(overlap // 2)
|
|
190
|
+
stitch_coords.append(coords[-1] + crop_coords[-1])
|
|
191
|
+
crop_size.append(patch_size - overlap)
|
|
192
|
+
else:
|
|
193
|
+
previous_crop_size = crop_size[-1] if crop_size else 1
|
|
194
|
+
previous_stitch_coord = stitch_coords[-1] if stitch_coords else 0
|
|
195
|
+
previous_tile_end = previous_stitch_coord + previous_crop_size
|
|
196
|
+
|
|
197
|
+
coords.append(max(0, axis_size - patch_size))
|
|
198
|
+
stitch_coords.append(previous_tile_end)
|
|
199
|
+
crop_coords.append(stitch_coords[-1] - coords[-1])
|
|
200
|
+
crop_size.append(axis_size - stitch_coords[-1])
|
|
201
|
+
|
|
202
|
+
return (
|
|
203
|
+
coords,
|
|
204
|
+
stitch_coords,
|
|
205
|
+
crop_coords,
|
|
206
|
+
crop_size,
|
|
207
|
+
)
|