careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Self
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tifffile
|
|
7
|
+
from numpy.typing import DTypeLike, NDArray
|
|
8
|
+
|
|
9
|
+
from careamics.dataset.dataset_utils import reshape_array
|
|
10
|
+
from careamics.file_io.read import ReadFunc, read_tiff
|
|
11
|
+
|
|
12
|
+
from .image_utils.image_stack_utils import channel_slice, pad_patch, reshape_array_shape
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FileImageStack:
    """
    An ImageStack implementation for data that is coming from a file.

    The data will not be loaded until the `load` method is called. The `close` method
    can be used to remove the internal reference to the data.

    Parameters
    ----------
    source : Path
        Path to the file containing the image data.
    axes : str
        The original axes of the data in the file, must be a subset of STCZYX.
    data_shape : tuple of int
        The shape of the data once reshaped to SC(Z)YX. It is used for channel bounds
        checks and patch clipping, and it is not updated by `load`.
    data_dtype : DTypeLike
        The data type of the image data.
    read_func : ReadFunc
        Function called with `source` (and `read_kwargs`) to read the file.
    read_kwargs : dict of {str: Any}, optional
        Additional keyword arguments forwarded to `read_func` when `load` is called.
    """

    def __init__(
        self,
        source: Path,
        axes: str,
        data_shape: tuple[int, ...],
        data_dtype: DTypeLike,
        read_func: ReadFunc,
        read_kwargs: dict[str, Any] | None = None,
    ):
        self.source = source
        self.axes = axes
        self.data_shape = data_shape
        self.data_dtype = data_dtype
        self.read_func = read_func
        self.read_kwargs = read_kwargs
        # Populated by `load`, cleared by `close`.
        self._data: NDArray | None = None

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """
        Extract a patch containing all channels for a given sample.

        Equivalent to `extract_channel_patch` with `channels=None`.

        Parameters
        ----------
        sample_idx : int
            Index along the sample (first) dimension.
        coords : Sequence of int
            The coordinates that define the start of the patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            The patch with dimensions C(Z)YX, zero-padded where it leaves the image.
        """
        return self.extract_channel_patch(sample_idx, None, coords, patch_size)

    def extract_channel_patch(
        self,
        sample_idx: int,
        channels: Sequence[int] | None,  # `channels = None` to select all channels
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """
        Extract a patch of the selected channels for a given sample.

        Parameters
        ----------
        sample_idx : int
            Index along the sample (first) dimension.
        channels : Sequence of int or None
            Channel indices to extract, or `None` to extract all channels.
        coords : Sequence of int
            The coordinates that define the start of the patch.
        patch_size : Sequence of int
            The size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            The patch with dimensions C(Z)YX, zero-padded where it leaves the image.

        Raises
        ------
        ValueError
            If the data has not been loaded, if `coords` and `patch_size` have
            different lengths, or if a channel index exceeds the number of channels.
        """
        if self._data is None:
            raise ValueError(
                "Cannot extract patch because data has not been loaded from "
                f"'{self.source}', the `load` method must be called first."
            )

        if (coord_dims := len(coords)) != (patch_dims := len(patch_size)):
            raise ValueError(
                "Patch coordinates and patch size must have the same dimensions but "
                f"found {coord_dims} and {patch_dims}."
            )

        # check that channels are within bounds
        if channels is not None:
            max_channel = self.data_shape[1] - 1  # channel is second dimension
            for ch in channels:
                if ch > max_channel:
                    raise ValueError(
                        f"Channel index {ch} is out of bounds for data with "
                        f"{self.data_shape[1]} channels. Check the provided `channels` "
                        f"parameter in the configuration for erroneous channel "
                        f"indices."
                    )

        patch_data = self._data[
            (
                sample_idx,  # type: ignore
                # use channel slice so that channel dimension is kept
                channel_slice(channels),  # type: ignore
                # clip to the image bounds; out-of-bounds regions are padded below
                *[
                    slice(
                        np.clip(c, 0, self.data_shape[2 + i]),
                        np.clip(c + ps, 0, self.data_shape[2 + i]),
                    )
                    for i, (c, ps) in enumerate(zip(coords, patch_size, strict=False))
                ],  # type: ignore
            )  # type: ignore
        ]
        patch = pad_patch(coords, patch_size, self.data_shape, patch_data)

        return patch

    def load(self):
        """
        Load the data stored in a file.

        `read_func` is called with `source` and any `read_kwargs`, and the result is
        reshaped from `axes` to SC(Z)YX.
        """
        read_kwargs = self.read_kwargs if self.read_kwargs is not None else {}
        data = self.read_func(self.source, **read_kwargs)
        self._data = reshape_array(data, self.axes)

    # TODO: maybe this should be called something else
    def close(self):
        """Remove the internal reference to the data to clear up memory."""
        # will get cleaned up by the garbage collector since there is no longer a ref
        self._data = None

    @property
    def is_loaded(self) -> bool:
        """Whether `load` has been called since construction or the last `close`."""
        return self._data is not None

    @classmethod
    def from_tiff(
        cls,
        path: Path,
        axes: str,
    ) -> Self:
        """
        Construct the `ImageStack` from a TIFF file.

        Only the shape and dtype of the file's first series are inspected here; the
        pixel data is read later by `load`.

        Parameters
        ----------
        path : Path
            Path to the TIFF file.
        axes : str
            The original axes of the data, must be a subset of STCZYX.

        Returns
        -------
        Self
            The `ImageStack` with the underlying data being from a TIFF file.
        """
        # TODO: think this is correct but need more examples to test
        # Use a context manager so the file handle is released once the metadata has
        # been read.
        with tifffile.TiffFile(path) as file:
            series = file.series[0]
            data_shape = reshape_array_shape(axes, series.shape)
            dtype = series.dtype
        return cls(
            source=path,
            axes=axes,
            data_shape=data_shape,
            data_dtype=dtype,
            read_func=read_tiff,
        )
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Literal, Protocol, TypeVar, Union
|
|
4
|
+
|
|
5
|
+
from numpy.typing import DTypeLike, NDArray
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ImageStack(Protocol):
|
|
9
|
+
"""
|
|
10
|
+
An interface for extracting patches from an image stack.
|
|
11
|
+
|
|
12
|
+
Attributes
|
|
13
|
+
----------
|
|
14
|
+
source: Path or "array"
|
|
15
|
+
Origin of the image data.
|
|
16
|
+
data_shape: Sequence[int]
|
|
17
|
+
The shape of the data, it is expected to be in the order (SC(Z)YX).
|
|
18
|
+
data_dtype: DTypeLike
|
|
19
|
+
The data type of the image data.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
@property
|
|
23
|
+
def source(self) -> Union[str, Path, Literal["array"]]: ...
|
|
24
|
+
|
|
25
|
+
"""Source of the image data."""
|
|
26
|
+
|
|
27
|
+
@property
|
|
28
|
+
def data_shape(self) -> Sequence[int]: ...
|
|
29
|
+
|
|
30
|
+
"""Shape of the image data."""
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
def data_dtype(self) -> DTypeLike: ...
|
|
34
|
+
|
|
35
|
+
"""Data type of the image data."""
|
|
36
|
+
|
|
37
|
+
def extract_patch(
|
|
38
|
+
self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
|
|
39
|
+
) -> NDArray:
|
|
40
|
+
"""
|
|
41
|
+
Extract a patch for a given sample within the image stack.
|
|
42
|
+
|
|
43
|
+
Parameters
|
|
44
|
+
----------
|
|
45
|
+
sample_idx: int
|
|
46
|
+
Sample index. The first dimension of the image data will be indexed at this
|
|
47
|
+
value.
|
|
48
|
+
coords: Sequence of int
|
|
49
|
+
The coordinates that define the start of a patch.
|
|
50
|
+
patch_size: Sequence of int
|
|
51
|
+
The size of the patch in each spatial dimension.
|
|
52
|
+
|
|
53
|
+
Returns
|
|
54
|
+
-------
|
|
55
|
+
numpy.ndarray
|
|
56
|
+
A patch of the image data from a particlular sample. It will have the
|
|
57
|
+
dimensions C(Z)YX.
|
|
58
|
+
"""
|
|
59
|
+
...
|
|
60
|
+
|
|
61
|
+
def extract_channel_patch(
|
|
62
|
+
self,
|
|
63
|
+
sample_idx: int,
|
|
64
|
+
channels: Sequence[int] | None,
|
|
65
|
+
coords: Sequence[int],
|
|
66
|
+
patch_size: Sequence[int],
|
|
67
|
+
) -> NDArray:
|
|
68
|
+
"""
|
|
69
|
+
Extract a patch of a single channel for a given sample within the image stack.
|
|
70
|
+
|
|
71
|
+
Parameters
|
|
72
|
+
----------
|
|
73
|
+
sample_idx: int
|
|
74
|
+
Sample index. The first dimension of the image data will be indexed at this
|
|
75
|
+
value.
|
|
76
|
+
channels: Sequence[int] | None
|
|
77
|
+
Channel indices to extract. If `None` is given all channels will be
|
|
78
|
+
extracted.
|
|
79
|
+
coords: Sequence of int
|
|
80
|
+
The coordinates that define the start of a patch.
|
|
81
|
+
patch_size: Sequence of int
|
|
82
|
+
The size of the patch in each spatial dimension.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
numpy.ndarray
|
|
87
|
+
A patch of the image data from a particlular sample. It will have the
|
|
88
|
+
dimensions C(Z)YX.
|
|
89
|
+
"""
|
|
90
|
+
...
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
GenericImageStack = TypeVar("GenericImageStack", bound=ImageStack, covariant=True)
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from types import EllipsisType
|
|
3
|
+
from typing import TypeVar
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
# Generic numpy scalar type; `pad_patch` uses it so its annotated return type carries
# the same scalar type parameter as its `patch_data` input.
T = TypeVar("T", bound=np.generic)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def channel_slice(
|
|
12
|
+
channels: Sequence[int] | None,
|
|
13
|
+
) -> EllipsisType | Sequence[int]:
|
|
14
|
+
"""Create a slice or sequence for indexing channels while preserving dimensions.
|
|
15
|
+
|
|
16
|
+
Parameters
|
|
17
|
+
----------
|
|
18
|
+
channels : Sequence[int] | None
|
|
19
|
+
The channel indices to select, or None to select all channels.
|
|
20
|
+
|
|
21
|
+
Returns
|
|
22
|
+
-------
|
|
23
|
+
EllipsisType | Sequence[int]
|
|
24
|
+
An indexing object that can be used to index the channel dimension while
|
|
25
|
+
preserving it.
|
|
26
|
+
"""
|
|
27
|
+
if channels is None:
|
|
28
|
+
return ...
|
|
29
|
+
|
|
30
|
+
if len(channels) == 0:
|
|
31
|
+
raise ValueError("Channel index sequence cannot be empty.")
|
|
32
|
+
|
|
33
|
+
return channels
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# TODO: add tests
|
|
37
|
+
# TODO: move to dataset_utils, better name?
|
|
38
|
+
def reshape_array_shape(
    original_axes: str, shape: Sequence[int], add_singleton: bool = True
) -> tuple[int, ...]:
    """Find resulting shape if reshaping array to SC(Z)YX.

    If `T` is present in the original axes, its size is multiplied into `S`, as both
    axes are multiplexed. If `S` is absent and `add_singleton` is `False`, `T` takes
    the place of `S` at the front of the shape.

    Setting `add_singleton` to `False` will only include axes that are present in
    `original_axes` in the output shape.

    Parameters
    ----------
    original_axes : str
        The axes of the original array, e.g. "TCZYX", "SCYX", etc.
    shape : Sequence[int]
        The shape of the original array.
    add_singleton : bool, default=True
        Whether to add singleton dimensions for missing axes. When `False`, only axes
        present in `original_axes` will be included in the output shape. When `True`,
        missing mandatory axes (`S` and `C`) will be added as singleton dimensions.

    Returns
    -------
    tuple of int
        The shape in SC(Z)YX order. `Z` is only included when present in
        `original_axes`.

    Raises
    ------
    ValueError
        If the number of axes does not match the number of dimensions in `shape`.
    """
    if len(original_axes) != len(shape):
        raise ValueError(
            f"Number of axes ({len(original_axes)}, '{original_axes}') does not match "
            f"the number of dimensions of the shape ({len(shape)}, {tuple(shape)})."
        )

    target_shape: list[int] = []
    for axis in "SCZYX":
        if axis in original_axes:
            target_shape.append(shape[original_axes.index(axis)])
        elif axis != "Z" and add_singleton:
            # missing mandatory axes (S, C) become singleton dimensions
            target_shape.append(1)

    if "T" in original_axes:
        t_size = shape[original_axes.index("T")]
        if "S" in original_axes or add_singleton:
            # S and T are multiplexed into a single sample dimension
            target_shape[0] *= t_size
        else:
            # no S dimension is present, so T takes its place at the front
            target_shape.insert(0, t_size)

    return tuple(target_shape)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def pad_patch(
    coords: Sequence[int],
    patch_size: Sequence[int],
    data_shape: Sequence[int],
    patch_data: NDArray[T],
) -> NDArray[T]:
    """
    Zero-pad patch data wherever the patch extends beyond its source image.

    The result always has the requested `patch_size` in the spatial dimensions.

    Along an axis with a negative coordinate, the beginning of the patch is filled
    with zeros up to where the image would start, and the patch data is placed from
    there on.

    Along an axis where `coords + patch_size` exceeds the image bounds, the end of the
    patch is filled with zeros.

    Parameters
    ----------
    coords : Sequence[int]
        The coordinates that describe where the patch starts in the spatial dimensions
        of the image.
    patch_size : Sequence[int]
        The size of the patch in the spatial dimensions.
    data_shape : Sequence[int]
        The shape of the image the patch originates from, must be in the format SC(Z)YX.
    patch_data : NDArray[T]
        The patch data to be padded, with the channel dimension first.

    Returns
    -------
    NDArray[T]
        The resulting padded patch.
    """
    start = np.array(coords)
    # Offset of the real data inside the padded patch: zero along every axis except
    # those where the patch starts before the image (negative coordinate).
    offset = np.clip(start, 0, None) - start
    stop = offset + np.array(patch_data.shape[1:])

    padded = np.zeros((patch_data.shape[0], *patch_size), dtype=patch_data.dtype)
    # keep every channel, place the data in the computed spatial window
    region = (slice(None), *(slice(a, b) for a, b in zip(offset, stop, strict=False)))
    padded[region] = patch_data
    return padded
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Literal, Self, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpy.typing import DTypeLike, NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.dataset.dataset_utils import reshape_array
|
|
9
|
+
from careamics.file_io.read import ReadFunc, read_tiff
|
|
10
|
+
|
|
11
|
+
from .image_utils.image_stack_utils import channel_slice, pad_patch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class InMemoryImageStack:
|
|
15
|
+
"""
|
|
16
|
+
A class for extracting patches from an image stack that has been loaded into memory.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(self, source: Union[Path, Literal["array"]], data: NDArray):
|
|
20
|
+
self.source: Union[str, Path, Literal["array"]] = source
|
|
21
|
+
# data expected to be in SC(Z)YX shape, reason to use from_array constructor
|
|
22
|
+
self._data: NDArray = data
|
|
23
|
+
self.data_shape: Sequence[int] = self._data.shape
|
|
24
|
+
self.data_dtype: DTypeLike = self._data.dtype
|
|
25
|
+
|
|
26
|
+
def extract_patch(
|
|
27
|
+
self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
|
|
28
|
+
) -> NDArray:
|
|
29
|
+
return self.extract_channel_patch(sample_idx, None, coords, patch_size)
|
|
30
|
+
|
|
31
|
+
def extract_channel_patch(
|
|
32
|
+
self,
|
|
33
|
+
sample_idx: int,
|
|
34
|
+
channels: Sequence[int] | None, # `channels = None` to select all channels
|
|
35
|
+
coords: Sequence[int],
|
|
36
|
+
patch_size: Sequence[int],
|
|
37
|
+
) -> NDArray:
|
|
38
|
+
if (coord_dims := len(coords)) != (patch_dims := len(patch_size)):
|
|
39
|
+
raise ValueError(
|
|
40
|
+
"Patch coordinates and patch size must have the same dimensions but "
|
|
41
|
+
f"found {coord_dims} ({coords}) and {patch_dims} ({patch_size})."
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# check that channels are within bounds
|
|
45
|
+
if channels is not None:
|
|
46
|
+
max_channel = self.data_shape[1] - 1 # channel is second dimension
|
|
47
|
+
for ch in channels:
|
|
48
|
+
if ch > max_channel:
|
|
49
|
+
raise ValueError(
|
|
50
|
+
f"Channel index {ch} is out of bounds for data with "
|
|
51
|
+
f"{self.data_shape[1]} channels. Check the provided `channels` "
|
|
52
|
+
f"parameter in the configuration for erroneous channel "
|
|
53
|
+
f"indices."
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
# TODO: test for 2D or 3D?
|
|
57
|
+
|
|
58
|
+
patch_data = self._data[
|
|
59
|
+
(
|
|
60
|
+
sample_idx, # type: ignore
|
|
61
|
+
# use channel slice so that channel dimension is kept
|
|
62
|
+
channel_slice(channels), # type: ignore
|
|
63
|
+
*[
|
|
64
|
+
slice(
|
|
65
|
+
np.clip(c, 0, self.data_shape[2 + i]),
|
|
66
|
+
np.clip(c + ps, 0, self.data_shape[2 + i]),
|
|
67
|
+
)
|
|
68
|
+
for i, (c, ps) in enumerate(zip(coords, patch_size, strict=False))
|
|
69
|
+
], # type: ignore
|
|
70
|
+
) # type: ignore
|
|
71
|
+
]
|
|
72
|
+
patch = pad_patch(coords, patch_size, self.data_shape, patch_data)
|
|
73
|
+
|
|
74
|
+
return patch
|
|
75
|
+
|
|
76
|
+
@classmethod
def from_array(cls, data: NDArray, axes: str) -> Self:
    """Build an image stack from an array that is already in memory.

    Parameters
    ----------
    data : NDArray
        Image data.
    axes : str
        Axes of `data`; used by `reshape_array` to bring the data into the
        layout expected by the stack.

    Returns
    -------
    Self
        A new image stack whose source is recorded as "array".
    """
    return cls(source="array", data=reshape_array(data, axes))
|
|
80
|
+
|
|
81
|
+
@classmethod
def from_tiff(cls, path: Path, axes: str) -> Self:
    """Build an image stack by reading a TIFF file into memory.

    Parameters
    ----------
    path : Path
        Path of the TIFF file; also recorded as the stack's source.
    axes : str
        Axes of the data stored in the file; used by `reshape_array`.

    Returns
    -------
    Self
        A new image stack holding the reshaped file contents.
    """
    reshaped = reshape_array(read_tiff(path), axes)
    return cls(source=path, data=reshaped)
|
|
86
|
+
|
|
87
|
+
@classmethod
def from_custom_file_type(
    cls, path: Path, axes: str, read_func: ReadFunc, **read_kwargs: Any
) -> Self:
    """Build an image stack from a file read with a user-supplied function.

    Parameters
    ----------
    path : Path
        Path of the file; also recorded as the stack's source.
    axes : str
        Axes of the array returned by `read_func`; used by `reshape_array`.
    read_func : ReadFunc
        Callable invoked as `read_func(path, **read_kwargs)` to load the data.
    **read_kwargs : Any
        Extra keyword arguments forwarded to `read_func`.

    Returns
    -------
    Self
        A new image stack holding the reshaped data.
    """
    raw_data = read_func(path, **read_kwargs)
    return cls(source=path, data=reshape_array(raw_data, axes))
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import zarr
|
|
4
|
+
from numpy.typing import DTypeLike, NDArray
|
|
5
|
+
|
|
6
|
+
from careamics.dataset.dataset_utils import reshape_array
|
|
7
|
+
|
|
8
|
+
from .image_utils.image_stack_utils import channel_slice, pad_patch, reshape_array_shape
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ZarrImageStack:
    """
    A class for extracting patches from an image stack that is stored as a zarr array.

    Data is only read from the store when a patch is requested. Patch requests
    use the order of `data_shape` (SC(Z)YX), where the sample axis merges the
    original `S` and `T` axes (S' = S*T). They are translated back into index
    expressions in the axis order of the stored array.
    """

    def __init__(self, group: zarr.Group, data_path: str, axes: str):
        """
        Parameters
        ----------
        group : zarr.Group
            Zarr group containing the array.
        data_path : str
            Path of the array relative to `group`.
        axes : str
            Axes of the stored array, in storage order. Must be a subset of
            "STCZYX" without duplicates, contain "Y" and "X", and have one
            entry per array dimension.

        Raises
        ------
        TypeError
            If `group` is not a `zarr.Group`, or `data_path` does not point to a
            `zarr.Array`.
        ValueError
            If no array is found at `data_path`, or if `axes` is invalid.
        """
        if not isinstance(group, zarr.Group):
            raise TypeError(f"group must be a zarr.Group instance, got {type(group)}.")

        self._group = group
        self._store = str(group.store_path)
        try:
            self._array = group[data_path]
        except KeyError as e:
            raise ValueError(
                f"Did not find array at '{data_path}' in store '{self._store}'."
            ) from e

        if not isinstance(self._array, zarr.Array):
            raise TypeError(
                f"data at path '{data_path}' must be a zarr.Array instance, "
                f"got {type(self._array)}."
            )

        self._source = self._array.store_path

        self._original_axes = axes
        self._original_data_shape: tuple[int, ...] = self._array.shape
        # validate before any shape computation relies on the axes
        self._validate_axes(axes, len(self._original_data_shape))
        self.data_shape = reshape_array_shape(axes, self._original_data_shape)
        self._data_dtype = self._array.dtype
        self._chunk_size = reshape_array_shape(
            axes, self._array.chunks, add_singleton=False
        )
        self._shard_size = (
            reshape_array_shape(axes, self._array.shards, add_singleton=False)
            if self._array.shards is not None
            else None
        )

    @staticmethod
    def _validate_axes(axes: str, ndim: int) -> None:
        """Raise `ValueError` if `axes` is not valid for an array with `ndim` dims."""
        unknown = set(axes) - set("STCZYX")
        if unknown:
            raise ValueError(
                f"Unrecognised axes {sorted(unknown)} in '{axes}', axes should be "
                f"in STCZYX."
            )
        if len(set(axes)) != len(axes):
            raise ValueError(f"Axes '{axes}' contain duplicate axes.")
        if "Y" not in axes or "X" not in axes:
            raise ValueError(f"Axes '{axes}' must contain both 'Y' and 'X'.")
        if len(axes) != ndim:
            raise ValueError(
                f"Axes '{axes}' ({len(axes)} axes) do not match the number of "
                f"dimensions of the array ({ndim})."
            )

    # Used to identify the source of the data and write to similar path during pred
    @property
    def source(self) -> str:
        """String form of the array's store path."""
        # e.g. file://data/bsd68.zarr/train/
        return str(self._source)

    @property
    def chunks(self) -> Sequence[int]:
        """Chunks size in the order of data_shape (SC(Z)YX)."""
        return self._chunk_size

    @property
    def shards(self) -> Sequence[int] | None:
        """Shard size in the order of data_shape (SC(Z)YX)."""
        return self._shard_size

    @property
    def data_dtype(self) -> DTypeLike:
        """Dtype of the stored array."""
        return self._data_dtype

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """Extract a patch containing every channel (see `extract_channel_patch`)."""
        return self.extract_channel_patch(sample_idx, None, coords, patch_size)

    def extract_channel_patch(
        self,
        sample_idx: int,
        channels: Sequence[int] | None,  # `channels = None` to select all channels,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """
        Extract a (possibly padded) patch for a subset of channels.

        Parameters
        ----------
        sample_idx : int
            Index along the merged sample axis (S' = S*T) of `data_shape`.
            Negative indices count from the end.
        channels : Sequence[int] or None
            Channel indices to keep; `None` selects all channels.
        coords : Sequence[int]
            Start coordinates of the patch along the spatial axes, (Z)YX.
            Coordinates are clipped to the array bounds before reading.
        patch_size : Sequence[int]
            Size of the patch along the spatial axes, (Z)YX.

        Returns
        -------
        NDArray
            The read data with the sample dimension removed (presumably
            C(Z)YX - confirm against `reshape_array`), passed through
            `pad_patch`.

        Raises
        ------
        ValueError
            If `coords` and `patch_size` lengths differ or do not match the
            number of spatial axes, or if a channel index is out of bounds.
        IndexError
            If `sample_idx` is out of bounds for the merged sample axis.
        """
        # original axes assumed to be any subset of STCZYX (containing YX), in any
        # order; arguments must be transformed to index data in original axes order.
        # To do this: loop through original axes and append the correct index/slice
        # for each axis. If any axis is not present in original_axes it is skipped.

        if (coord_dims := len(coords)) != (patch_dims := len(patch_size)):
            raise ValueError(
                "Patch coordinates and patch size must have the same dimensions but "
                f"found {coord_dims} ({coords}) and {patch_dims} ({patch_size})."
            )

        spatial_axes = "ZYX" if "Z" in self._original_axes else "YX"
        if coord_dims != len(spatial_axes):
            raise ValueError(
                f"Patch coordinates have {coord_dims} dimensions but the data has "
                f"{len(spatial_axes)} spatial dimensions ({spatial_axes})."
            )

        # bounds check over the merged S*T axis; without it a T-only array would
        # silently wrap out-of-range indices through the modulo in `_get_T_index`
        n_samples = self._n_samples()
        if not -n_samples <= sample_idx < n_samples:
            raise IndexError(
                f"Sample index {sample_idx} out of bounds for S axes with size "
                f"{n_samples}"
            )

        # check that channels are within bounds
        if channels is not None:
            max_channel = self.data_shape[1] - 1  # channel is second dimension
            for ch in channels:
                if ch > max_channel:
                    raise ValueError(
                        f"Channel index {ch} is out of bounds for data with "
                        f"{self.data_shape[1]} channels. Check the provided `channels` "
                        f"parameter in the configuration for erroneous channel "
                        f"indices."
                    )

        # clip start/stop into [0, dim]: a negative slice start would otherwise
        # wrap around to the end of the axis instead of reading from 0
        spatial_slices: dict[str, slice] = {}
        for axis, start, size in zip(spatial_axes, coords, patch_size, strict=True):
            dim = self._original_data_shape[self._original_axes.index(axis)]
            spatial_slices[axis] = slice(
                min(max(start, 0), dim), min(max(start + size, 0), dim)
            )

        patch_slice: list[int | slice] = []
        for d in self._original_axes:
            if d == "S":
                patch_slice.append(self._get_S_index(sample_idx))
            elif d == "T":
                patch_slice.append(self._get_T_index(sample_idx))
            elif d == "C":
                patch_slice.append(channel_slice(channels))  # type: ignore
            elif d in spatial_slices:
                patch_slice.append(spatial_slices[d])
            else:
                raise ValueError(f"Unrecognised axis '{d}', axes should be in STCZYX.")

        patch_data: NDArray = self._array[tuple(patch_slice)]  # type: ignore
        # S and T were indexed with integers, so they are no longer in the data
        patch_axes = self._original_axes.replace("S", "").replace("T", "")
        patch_data = reshape_array(patch_data, patch_axes)[0]  # remove first sample dim
        return pad_patch(coords, patch_size, self.data_shape, patch_data)

    def _n_samples(self) -> int:
        """Size of the merged sample axis, S*T (1 if neither axis is present)."""
        n_samples = 1
        for axis in "ST":
            if axis in self._original_axes:
                n_samples *= self._original_data_shape[
                    self._original_axes.index(axis)
                ]
        return n_samples

    def _get_T_index(self, sample_idx: int) -> int:
        """Get T index given `sample_idx`.

        Assumes the merged sample axis (S' = S*T) is laid out with T varying
        fastest, i.e. `T_idx = S_idx' % T_size` and `S_idx = S_idx' // T_size`
        (this must match how `reshape_array` merges S and T).
        """
        if "T" not in self._original_axes:
            raise ValueError("No 'T' axis specified in original data axes.")
        axis_idx = self._original_axes.index("T")
        dim = self._original_data_shape[axis_idx]

        # modulus finds how far along the row, i.e. the column (the T index)
        return sample_idx % dim

    def _get_S_index(self, sample_idx: int) -> int:
        """Get S index given `sample_idx`.

        With a T axis, the merged sample axis (S' = S*T) is assumed to have T
        varying fastest, so `S_idx = S_idx' // T_size`. Without T the index is
        returned unchanged.
        """
        if "S" not in self._original_axes:
            raise ValueError("No 'S' axis specified in original data axes.")
        if "T" in self._original_axes:
            T_axis_idx = self._original_axes.index("T")
            T_dim = self._original_data_shape[T_axis_idx]

            # floor divide finds the row (the S index)
            return sample_idx // T_dim
        else:
            return sample_idx
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Public names re-exported by this package: the loader protocol followed by
# the concrete loader functions, in alphabetical order.
__all__ = [
    "ImageStackLoader",
    # loader functions
    "load_arrays",
    "load_custom_file",
    "load_czis",
    "load_iter_tiff",
    "load_tiffs",
    "load_zarrs",
]
|
|
10
|
+
|
|
11
|
+
from .image_stack_loader_protocol import ImageStackLoader
|
|
12
|
+
from .image_stack_loaders import (
|
|
13
|
+
load_arrays,
|
|
14
|
+
load_custom_file,
|
|
15
|
+
load_czis,
|
|
16
|
+
load_iter_tiff,
|
|
17
|
+
load_tiffs,
|
|
18
|
+
load_zarrs,
|
|
19
|
+
)
|