careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from tqdm import tqdm
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class EmptyPatchFetcher:
    """
    Fetch "empty" (low-intensity) patches so that real content can be replaced with them.

    A patch start location ``(n, h, w)`` is considered empty when the maximum of the
    ``patch_size x patch_size`` window ``data_frames[n, h:h + patch_size, w:w + patch_size]``
    is non-negative and strictly below ``max_val_threshold``. The dataset indices of all
    such windows are collected once, at construction, and sampled uniformly by
    :meth:`sample`.

    Parameters
    ----------
    idx_manager
        Object exposing ``get_dataset_idx_from_location(loc)`` and
        ``get_location_from_dataset_idx(idx)``. Locations are 4-tuples
        ``(n, h, w, x)``; the last element is always passed as ``0`` here
        (presumably a channel index -- TODO confirm against the index manager).
    patch_size : int
        Side length of the square window used to compute the local maximum.
    data_frames : numpy.ndarray
        Stack of frames; must have exactly three axes ``(N, H, W)``.
    max_val_threshold : float
        Windows whose maximum is below this value count as empty. Although the
        signature defaults to ``None``, a value is required.

    Raises
    ------
    ValueError
        If ``max_val_threshold`` is ``None`` or if no empty patch is found.
    """

    def __init__(self, idx_manager, patch_size, data_frames, max_val_threshold=None):
        if max_val_threshold is None:
            # The threshold is compared numerically against window maxima; without it
            # the comparison in ``set_empty_idx`` would fail with an opaque TypeError.
            raise ValueError(
                f"[{self.__class__.__name__}] max_val_threshold must be provided."
            )
        self._frames = data_frames
        self._idx_manager = idx_manager
        self._max_val_threshold = max_val_threshold
        self._idx_list = []
        self._patch_size = patch_size
        # Each sampled index refers to a single grid cell.
        self._grid_size = 1
        self.set_empty_idx()

        print(f"[{self.__class__.__name__}] MaxVal:{self._max_val_threshold}")

    def compute_max(self, window):
        """
        Compute the maximum of every ``window x window`` patch of each frame.

        Parameters
        ----------
        window : int
            Side length of the square window.

        Returns
        -------
        numpy.ndarray
            Float64 array of shape ``(N, H - window, W - window)`` where entry
            ``[n, h, w]`` equals ``frames[n, h:h + window, w:w + window].max()``.
            The final start row ``H - window`` and start column ``W - window``
            are not included.
        """
        from numpy.lib.stride_tricks import sliding_window_view

        N, H, W = self._frames.shape
        assert self._grid_size == 1

        frames = np.asarray(self._frames)
        # A 2D window maximum is separable: take the max over sliding rows first,
        # then over sliding columns. Both views are zero-copy, so only the reduced
        # results are allocated.
        row_max = sliding_window_view(frames, window, axis=1).max(axis=-1)
        window_max = sliding_window_view(row_max, window, axis=2).max(axis=-1)
        return window_max[:, : H - window, : W - window].astype(np.float64)

    def set_empty_idx(self):
        """
        Collect the dataset indices of all empty patches into ``self._idx_list``.

        Every empty window start ``(n, h, w)`` is converted to a dataset index with the
        index manager. The index manager must map that index back to the same location;
        this is checked with assertions.

        Raises
        ------
        ValueError
            If no window satisfies ``0 <= max < max_val_threshold``.
        """
        max_data = self.compute_max(self._patch_size)
        # Windows with negative maxima are excluded as well as those above threshold.
        is_empty = np.logical_and(max_data >= 0, max_data < self._max_val_threshold)

        idx_list = []
        for n_idx, h_start, w_start in zip(*np.nonzero(is_empty)):
            loc = (n_idx, h_start, w_start, 0)
            dataset_idx = self._idx_manager.get_dataset_idx_from_location(loc)
            t, h, w, _ = self._idx_manager.get_location_from_dataset_idx(dataset_idx)
            # Round-trip consistency check of the index manager.
            assert h == h_start, f"{h} != {h_start}"
            assert w == w_start, f"{w} != {w_start}"
            assert t == n_idx, f"{t} != {n_idx}"
            idx_list.append(dataset_idx)

        self._idx_list = np.array(idx_list)

        if len(self._idx_list) == 0:
            raise ValueError(
                f"[{self.__class__.__name__}] No empty patch found with "
                f"patch_size={self._patch_size} and "
                f"max_val_threshold={self._max_val_threshold}."
            )

    def sample(self):
        """
        Draw one empty-patch index uniformly at random.

        Returns
        -------
        tuple
            ``(dataset_idx, grid_size)``, where ``grid_size`` is always ``1``.
        """
        return (np.random.choice(self._idx_list), self._grid_size)
|
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from careamics.lvae_training.dataset.types import TilingMode
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class GridIndexManager:
    """
    Map between flat dataset indices, per-dimension grid indices and grid-start coordinates.

    The data is divided into a regular grid of cells of size ``grid_shape``. Each cell
    is the centre of a patch of size ``patch_shape``. ``patch_shape - grid_shape`` must
    be non-negative and even in every dimension, so the patch extends equally on both
    sides of its cell. Flat dataset indices enumerate grid cells in row-major order:
    the last dimension varies fastest. Border handling depends on ``tiling_mode``.

    Attributes
    ----------
    data_shape : tuple
        Shape of the data being tiled.
    grid_shape : tuple
        Grid cell size per dimension; controls the overlap of neighbouring patches.
    patch_shape : tuple
        Patch size per dimension.
    tiling_mode : TilingMode
        Border handling strategy (``PadBoundary``, ``TrimBoundary`` or
        ``ShiftBoundary``).
    """

    data_shape: tuple
    grid_shape: tuple
    patch_shape: tuple
    tiling_mode: TilingMode

    # Patches are centred on grid cells. The grid is only used during validation /
    # testing, where grid_shape controls the overlap between patches; during
    # training random patches are sampled instead.

    def __post_init__(self):
        """
        Validate that the shapes are mutually consistent.

        Raises
        ------
        AssertionError
            If ``data_shape``, ``grid_shape`` and ``patch_shape`` differ in length.
        ValueError
            If ``patch_shape - grid_shape`` is negative or odd in any dimension.
        """
        assert len(self.data_shape) == len(
            self.grid_shape
        ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
        assert len(self.data_shape) == len(
            self.patch_shape
        ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
                )
            if pad % 2 != 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
                )

    def _assert_valid_dim(self, dim: int):
        """Assert that ``dim`` is a valid axis index of ``data_shape``."""
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

    def patch_offset(self):
        """Return the per-dimension offset between a patch start and its grid-cell start."""
        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2

    def get_individual_dim_grid_count(self, dim: int):
        """
        Return the number of grid cells along ``dim``, ignoring all other dimensions.

        ``PadBoundary`` rounds up; ``ShiftBoundary`` rounds up after removing the
        patch margin; any other mode (``TrimBoundary``) rounds down after removing
        the patch margin.
        """
        self._assert_valid_dim(dim)

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return self.data_shape[dim]
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(
                np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
            )
        else:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(
                np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
            )

    def total_grid_count(self):
        """Return the total number of grid cells in the dataset."""
        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

    def grid_count(self, dim: int):
        """
        Return the number of grid cells spanned by one step along ``dim``.

        This is the product of the per-dimension grid counts of all dimensions after
        ``dim``, i.e. the row-major stride of ``dim``. It is 1 for the last dimension.
        """
        self._assert_valid_dim(dim)
        count = 1
        for later_dim in range(dim + 1, len(self.data_shape)):
            count *= self.get_individual_dim_grid_count(later_dim)
        return count

    def get_grid_index(self, dim: int, coordinate: int):
        """
        Return the grid index along ``dim`` of the cell containing ``coordinate``.

        Returns
        -------
        int
            Grid index along ``dim``. ``coordinate`` itself is returned unchanged
            when both grid and patch size are 1 in that dimension.

        Raises
        ------
        ValueError
            If ``tiling_mode`` is not supported.
        """
        self._assert_valid_dim(dim)
        assert (
            coordinate < self.data_shape[dim]
        ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return coordinate
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.floor(coordinate / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # can be <0 if coordinate is in [0, excess_size), hence the clamp
            return max(
                0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim]))
            )
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
                # The last cell is shifted inwards so the patch ends at the data border.
                return self.get_individual_dim_grid_count(dim) - 1
            # can be <0 if coordinate is in [0, excess_size), hence the clamp
            return max(
                0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim]))
            )
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def dataset_idx_from_grid_idx(self, grid_idx: tuple):
        """Return the flat (row-major) dataset index of the cell with grid indices ``grid_idx``."""
        assert len(grid_idx) == len(
            self.data_shape
        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
        index = 0
        for dim, dim_idx in enumerate(grid_idx):
            index += dim_idx * self.grid_count(dim)
        return index

    def get_patch_location_from_dataset_idx(self, dataset_idx: int):
        """Return the start coordinate of the patch (not the grid cell) for ``dataset_idx``."""
        grid_location = self.get_location_from_dataset_idx(dataset_idx)
        offset = self.patch_offset()
        return tuple(np.array(grid_location) - np.array(offset))

    def get_dataset_idx_from_grid_location(self, location: tuple):
        """Return the flat dataset index of the grid cell containing ``location``."""
        assert len(location) == len(
            self.data_shape
        ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
        grid_idx = [
            self.get_grid_index(dim, location[dim]) for dim in range(len(location))
        ]
        return self.dataset_idx_from_grid_idx(tuple(grid_idx))

    def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
        """
        Return the start coordinate along ``dim`` of the grid cell with index ``dim_index``.

        Raises
        ------
        ValueError
            If ``tiling_mode`` is not supported.
        """
        self._assert_valid_dim(dim)
        # NOTE: a bounds check `dim_index < get_individual_dim_grid_count(dim)` was
        # disabled upstream because it failed in some use; the cause was not
        # investigated, so out-of-range indices are not rejected here.
        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return dim_index
        elif self.tiling_mode == TilingMode.PadBoundary:
            return dim_index * self.grid_shape[dim]
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            return dim_index * self.grid_shape[dim] + excess_size
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            if dim_index < self.get_individual_dim_grid_count(dim) - 1:
                return dim_index * self.grid_shape[dim] + excess_size
            # On the boundary: place the cell so the patch ends exactly at the border.
            return self.data_shape[dim] - self.grid_shape[dim] - excess_size
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def get_location_from_dataset_idx(self, dataset_idx: int):
        """Return the grid-cell start coordinates (one per dimension) for ``dataset_idx``."""
        grid_idx = []
        remainder = dataset_idx
        for dim in range(len(self.data_shape)):
            dim_idx, remainder = divmod(remainder, self.grid_count(dim))
            grid_idx.append(dim_idx)
        location = [
            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
            for dim in range(len(self.data_shape))
        ]
        return tuple(location)

    def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
        """
        Return True if the cell ``dataset_idx`` is the first or last cell along ``dim``.

        With ``only_end=True`` only the last cell along ``dim`` counts as a boundary.
        """
        self._assert_valid_dim(dim)

        if dim > 0:
            # Strip the contribution of all preceding (slower-varying) dimensions.
            dataset_idx = dataset_idx % self.grid_count(dim - 1)

        dim_index = dataset_idx // self.grid_count(dim)
        if only_end:
            return dim_index == self.get_individual_dim_grid_count(dim) - 1

        return (
            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
        )

    def next_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Return the dataset index one step forward along ``dim``, or None past the end.

        NOTE: only the overall index range is checked. At the last cell along ``dim``
        the result moves into the next cell of the preceding dimension rather than
        returning None.
        """
        self._assert_valid_dim(dim)
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Return the dataset index one step backward along ``dim``, or None before the start.

        NOTE: only the overall index range is checked. At the first cell along ``dim``
        the result moves into the previous cell of the preceding dimension rather than
        returning None.
        """
        self._assert_valid_dim(dim)
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
        return new_idx
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@dataclass
|
|
236
|
+
class GridIndexManagerRef:
|
|
237
|
+
data_shapes: tuple
|
|
238
|
+
grid_shape: tuple
|
|
239
|
+
patch_shape: tuple
|
|
240
|
+
tiling_mode: TilingMode
|
|
241
|
+
|
|
242
|
+
# This class is used to calculate and store information about patches, and calculate
|
|
243
|
+
# the total length of the dataset in patches.
|
|
244
|
+
# It introduces a concept of a grid, to which input images are split.
|
|
245
|
+
# The grid is defined by the grid_shape and patch_shape, with former controlling the
|
|
246
|
+
# overlap.
|
|
247
|
+
# In this reimplementation it can accept multiple channels with different lengths,
|
|
248
|
+
# and every image can have different shape.
|
|
249
|
+
|
|
250
|
+
def __post_init__(self):
|
|
251
|
+
if len(self.data_shapes) > 1:
|
|
252
|
+
assert {len(ds) for ds in self.data_shapes[0]}.pop() == {
|
|
253
|
+
len(ds) for ds in self.data_shapes[1]
|
|
254
|
+
}.pop(), "Data shape for all channels must be the same" # TODO better way to assert this
|
|
255
|
+
assert {len(ds) for ds in self.data_shapes[0]}.pop() == len(
|
|
256
|
+
self.grid_shape
|
|
257
|
+
), "Data shape and grid size must have the same dimension"
|
|
258
|
+
assert {len(ds) for ds in self.data_shapes[0]}.pop() == len(
|
|
259
|
+
self.patch_shape
|
|
260
|
+
), "Data shape and patch shape must have the same dimension"
|
|
261
|
+
innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
|
|
262
|
+
for dim, pad in enumerate(innerpad):
|
|
263
|
+
if pad < 0:
|
|
264
|
+
raise ValueError(
|
|
265
|
+
f"Patch shape must be greater than or equal to grid shape in dimension {dim}"
|
|
266
|
+
)
|
|
267
|
+
if pad % 2 != 0:
|
|
268
|
+
raise ValueError(
|
|
269
|
+
f"Patch shape must have even padding in dimension {dim}"
|
|
270
|
+
)
|
|
271
|
+
self.num_patches_per_channel = self.total_grid_count()[1]
|
|
272
|
+
|
|
273
|
+
def patch_offset(self):
|
|
274
|
+
return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
|
|
275
|
+
|
|
276
|
+
def get_individual_dim_grid_count(self, shape: tuple, dim: int):
|
|
277
|
+
"""
|
|
278
|
+
Returns the number of the grid in the specified dimension, ignoring all other dimensions.
|
|
279
|
+
"""
|
|
280
|
+
# assert that dim is less than the number of dimensions in data shape
|
|
281
|
+
|
|
282
|
+
# if dim > len()
|
|
283
|
+
if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
|
|
284
|
+
return shape[dim]
|
|
285
|
+
elif self.tiling_mode == TilingMode.PadBoundary:
|
|
286
|
+
return int(np.ceil(shape[dim] / self.grid_shape[dim]))
|
|
287
|
+
elif self.tiling_mode == TilingMode.ShiftBoundary:
|
|
288
|
+
excess_size = self.patch_shape[dim] - self.grid_shape[dim]
|
|
289
|
+
return int(np.ceil((shape[dim] - excess_size) / self.grid_shape[dim]))
|
|
290
|
+
# if dim_index < self.get_individual_dim_grid_count(dim) - 1:
|
|
291
|
+
# return dim_index * self.grid_shape[dim] + excess_size
|
|
292
|
+
# on boundary. grid should be placed such that the patch covers the entire data.
|
|
293
|
+
# return self.data_shape[dim] - self.grid_shape[dim] - excess_size
|
|
294
|
+
else:
|
|
295
|
+
excess_size = self.patch_shape[dim] - self.grid_shape[dim]
|
|
296
|
+
return int(np.floor((shape[dim] - excess_size) / self.grid_shape[dim]))
|
|
297
|
+
|
|
298
|
+
def total_grid_count(self):
    """Count the patches of every sample in every channel.

    Returns
    -------
    tuple
        ``(len_per_channel, num_patches_per_sample)``.
        ``num_patches_per_sample[c][s]`` is the product of the per-dimension
        grid counts of sample ``s`` in channel ``c``.
        ``len_per_channel[c]`` is the sum of those counts over channel ``c``.
    """
    per_sample = [
        [
            np.prod(self.grid_count_per_sample(sample_shape))
            for sample_shape in channel_shapes
        ]
        for channel_shapes in self.data_shapes
    ]
    per_channel = [np.sum(counts) for counts in per_sample]
    return per_channel, per_sample
|
|
310
|
+
|
|
311
|
+
def grid_count_per_sample(self, shape: tuple):
    """Grid cell count along each dimension of a single sample.

    Parameters
    ----------
    shape : tuple
        Shape of the sample.

    Returns
    -------
    list
        ``get_individual_dim_grid_count(shape, d)`` for every dimension ``d``
        of ``shape``, in order.
    """
    return [
        self.get_individual_dim_grid_count(shape, axis)
        for axis, _ in enumerate(shape)
    ]
|
|
317
|
+
|
|
318
|
+
def get_grid_index(self, shape, dim: int, coordinate: int):
    """Return the index of the grid cell containing ``coordinate`` along ``dim``.

    Parameters
    ----------
    shape : tuple
        Shape of the sample that ``coordinate`` refers to.
    dim : int
        Dimension along which the coordinate is given.
    coordinate : int
        Position along ``dim``; must be smaller than ``shape[dim]``.

    Returns
    -------
    int
        Grid index along ``dim``.

    Raises
    ------
    ValueError
        If ``self.tiling_mode`` is not supported.
    """
    assert dim < len(
        shape
    ), f"Dimension {dim} is out of bounds for data shape {shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"
    assert (
        coordinate < shape[dim]
    ), f"Coordinate {coordinate} is out of bounds for data shape {shape}"

    if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
        return coordinate
    elif self.tiling_mode == TilingMode.PadBoundary:
        return int(np.floor(coordinate / self.grid_shape[dim]))
    elif self.tiling_mode == TilingMode.TrimBoundary:
        excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
        # can be < 0 when the coordinate lies in the leading margin; clamp to 0
        return max(0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim])))
    elif self.tiling_mode == TilingMode.ShiftBoundary:
        excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
        # The last cell is shifted so its patch ends at the sample boundary.
        # Compare against this sample's extent: self.data_shapes is a
        # per-channel list of shapes, not an extent.
        if coordinate + self.grid_shape[dim] + excess_size == shape[dim]:
            return self.get_individual_dim_grid_count(shape, dim) - 1
        # can be < 0 when the coordinate lies in the leading margin; clamp to 0
        return max(0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim])))
    else:
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
|
|
348
|
+
|
|
349
|
+
def patch_idx_from_grid_idx(self, shape: tuple, grid_idx: tuple):
    """Flatten a per-dimension grid index into a single patch index.

    Each component ``grid_idx[d]`` is multiplied by
    ``self.grid_count(shape, d)`` and the products are summed.

    NOTE(review): ``grid_count`` is not defined in the visible part of this
    class; presumably it returns the stride of dimension ``d`` — confirm.
    """
    assert len(grid_idx) == len(
        shape
    ), f"Dimension indices {grid_idx} must have the same dimension as data shape {shape}"
    return sum(
        position * self.grid_count(shape, axis)
        for axis, position in enumerate(grid_idx)
    )
|
|
358
|
+
|
|
359
|
+
def get_patch_location_from_patch_idx(self, ch_idx: int, patch_idx: int):
    """Start location of the patch (not of its grid cell) for a patch index.

    The grid-start location from ``get_location_from_patch_idx`` is moved back
    by ``patch_offset()`` in each data dimension. Its leading entry (the
    sample index) is left unchanged.

    Returns
    -------
    tuple
        The shifted location, one entry per element of the grid location.
    """
    grid_start = np.array(self.get_location_from_patch_idx(ch_idx, patch_idx))
    # prepend 0 so the sample-index entry is not shifted
    shift = np.concatenate((np.array([0]), self.patch_offset()))
    return tuple(grid_start - shift)
|
|
364
|
+
|
|
365
|
+
def get_patch_idx_from_grid_location(self, shape, location: tuple):
    """Return the patch index of the grid cell that contains ``location``.

    Parameters
    ----------
    shape : tuple
        Shape of the sample that ``location`` refers to.
    location : tuple
        One coordinate per dimension of ``shape``.

    Returns
    -------
    The flat index produced by ``patch_idx_from_grid_idx``.
    """
    assert len(location) == len(
        shape
    ), f"Location {location} must have the same dimension as data shape {shape}"
    # Both helpers take the sample shape as their first argument.
    grid_idx = tuple(
        self.get_grid_index(shape, dim, coord) for dim, coord in enumerate(location)
    )
    return self.patch_idx_from_grid_idx(shape, grid_idx)
|
|
373
|
+
|
|
374
|
+
def get_gridstart_location_from_dim_index(
    self, shape: tuple, dim_idx: int, dim: int
):
    """Return the grid-start coordinate of one grid cell along one dimension.

    This is the inverse of ``get_grid_index`` for each supported tiling mode.

    Parameters
    ----------
    shape : tuple
        Shape of the sample the grid is laid over.
    dim_idx : int
        Index of the dimension in the data shape.
    dim : int
        Value of the dimension in the grid, i.e. the grid-cell index along
        ``dim_idx`` (relative to the number of patches in that dimension).

    Returns
    -------
    int
        Coordinate along ``dim_idx`` at which the grid cell starts.

    Raises
    ------
    ValueError
        If ``self.tiling_mode`` is not supported.
    """
    if self.grid_shape[dim_idx] == 1 and self.patch_shape[dim_idx] == 1:
        # NOTE(review): this returns the dimension index, not the grid index
        # ``dim``. For the channel entry that get_location_from_patch_idx
        # passes (dim_idx=0) it yields 0 instead of the channel index. For
        # other trivial dimensions, ``dim`` was possibly intended — confirm.
        return dim_idx
    elif self.tiling_mode == TilingMode.PadBoundary:
        # cells start at multiples of the grid size
        return dim * self.grid_shape[dim_idx]
    elif self.tiling_mode == TilingMode.TrimBoundary:
        # cells start after half the patch margin
        excess_size = (self.patch_shape[dim_idx] - self.grid_shape[dim_idx]) // 2
        return dim * self.grid_shape[dim_idx] + excess_size
    elif self.tiling_mode == TilingMode.ShiftBoundary:
        excess_size = (self.patch_shape[dim_idx] - self.grid_shape[dim_idx]) // 2
        if dim < self.get_individual_dim_grid_count(shape, dim_idx) - 1:
            return dim * self.grid_shape[dim_idx] + excess_size
        else:
            # on boundary. grid should be placed such that the patch covers the entire data.
            return shape[dim_idx] - self.grid_shape[dim_idx] - excess_size
    else:
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
|
|
395
|
+
|
|
396
|
+
def get_location_from_patch_idx(self, channel_idx: int, patch_idx: int):
    """
    Return the grid-start location of a patch within one channel.

    Parameters
    ----------
    channel_idx : int
        Channel whose samples are indexed.
    patch_idx : int
        Index of the patch across all samples of the channel (samples are
        concatenated in order). Channels can differ in length.

    Returns
    -------
    tuple
        ``(sample_idx, *starts)``, where ``starts[d]`` is the grid-start
        coordinate along dimension ``d`` of the sample. For dimension 0, the
        grid index passed is ``channel_idx``.

    Raises
    ------
    IndexError
        If ``patch_idx`` is not in ``[0, number of patches in the channel)``.
    """
    # TODO duplicated runs, revisit: total_grid_count() recomputes all channels
    num_patches = self.total_grid_count()[1][channel_idx]
    cumulative_indices = np.cumsum(num_patches)
    total = int(cumulative_indices[-1]) if len(cumulative_indices) else 0
    if not 0 <= patch_idx < total:
        raise IndexError(
            f"Patch index {patch_idx} is out of range for channel {channel_idx} "
            f"with {total} patches"
        )

    # find the sample containing the patch
    sample_idx = int(np.searchsorted(cumulative_indices, patch_idx, side="right"))
    # patch_idx counts over all samples of the channel; rebase it onto the
    # selected sample before splitting it into per-dimension grid indices
    sample_start = int(cumulative_indices[sample_idx - 1]) if sample_idx > 0 else 0
    local_idx = patch_idx - sample_start
    sample_shape = self.data_shapes[channel_idx][sample_idx]

    # ignoring the channel dimension because we index it explicitly
    grid_count = self.grid_count_per_sample(sample_shape)[1:]

    # the first non-channel dimension varies fastest
    grid_idx = []
    for i in range(len(grid_count) - 1, -1, -1):
        stride = int(np.prod(grid_count[:i])) if i > 0 else 1
        grid_idx.insert(0, local_idx // stride)
        local_idx %= stride
    # TODO check for 3D !
    # adding channel index
    grid_idx = [channel_idx] + grid_idx

    location = [sample_idx] + [
        self.get_gridstart_location_from_dim_index(
            shape=sample_shape, dim_idx=dim_idx, dim=grid_idx[dim_idx]
        )
        for dim_idx in range(len(grid_idx))
    ]
    return tuple(location)
|
|
433
|
+
|
|
434
|
+
def get_location_from_patch_idx_o(self, dataset_idx: int):
    """
    Return the grid-start location for a flat index (single-shape variant).

    ``dataset_idx`` is split dimension by dimension over ``self.data_shape``.
    ``self.grid_count(dim)`` is used as the divisor for each dimension, and
    each resulting grid index is converted with
    ``get_grid_start_location_from_dim_index``.

    NOTE(review): ``self.data_shape`` and ``self.grid_count`` are not defined
    in the visible part of this class — confirm they exist before using this.
    """
    grid_idx = []
    for dim in range(len(self.data_shape)):
        index_along_dim, dataset_idx = divmod(dataset_idx, self.grid_count(dim))
        grid_idx.append(index_along_dim)
    # get_gridstart_location_from_dim_index expects (shape, dim_idx, dim)
    location = [
        self.get_gridstart_location_from_dim_index(self.data_shape, dim, grid_idx[dim])
        for dim in range(len(self.data_shape))
    ]
    return tuple(location)
|
|
447
|
+
|
|
448
|
+
def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
    """
    Returns True if the grid is on the boundary in the specified dimension.

    Parameters
    ----------
    dataset_idx : int
        Flat grid index.
    dim : int
        Dimension to test.
    only_end : bool, default False
        If True, only the last grid index along ``dim`` counts as boundary.
        Otherwise grid index 0 counts as well.

    NOTE(review): this looks like a leftover from a single-shape index
    manager.
    - ``self.data_shapes`` is iterated per channel in ``total_grid_count``, so
      ``len(self.data_shapes)`` is the number of channels, not of dimensions.
    - ``get_individual_dim_grid_count`` is declared as ``(shape, dim)`` in
      this class but is called below with ``dim`` only, which raises
      ``TypeError``.
    Confirm whether this method is still used before relying on it.
    """
    assert dim < len(
        self.data_shapes
    ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"

    # NOTE(review): presumably grid_count(d) is the stride of dimension d;
    # grid_count is not defined in the visible part of this class — confirm.
    if dim > 0:
        # drop the contribution of the preceding dimensions
        dataset_idx = dataset_idx % self.grid_count(dim - 1)

    # grid index along ``dim``
    dim_index = dataset_idx // self.grid_count(dim)
    if only_end:
        return dim_index == self.get_individual_dim_grid_count(dim) - 1

    return (
        dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
    )
|
|
467
|
+
|
|
468
|
+
def next_grid_along_dim(self, dataset_idx: int, dim: int):
    """
    Returns the index of the grid in the specified dimension in the specified direction.

    Steps ``dataset_idx`` forward by ``self.grid_count(dim)``. Returns
    ``None`` when the new index is past the end.

    NOTE(review):
    - ``total_grid_count()`` in this class returns a
      ``(len_per_channel, num_patches_per_sample)`` tuple, so the comparison
      below (int >= tuple) raises ``TypeError``. The bound presumably should
      be the patch count of a single channel — confirm with callers.
    - ``len(self.data_shapes)`` is the number of channels, not of dimensions.
    - ``grid_count`` is not defined in the visible part of this class.
    """
    assert dim < len(
        self.data_shapes
    ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"
    new_idx = dataset_idx + self.grid_count(dim)
    if new_idx >= self.total_grid_count():
        # stepped past the last grid
        return None
    return new_idx
|
|
480
|
+
|
|
481
|
+
def prev_grid_along_dim(self, dataset_idx: int, dim: int):
|
|
482
|
+
"""
|
|
483
|
+
Returns the index of the grid in the specified dimension in the specified direction.
|
|
484
|
+
"""
|
|
485
|
+
assert dim < len(
|
|
486
|
+
self.data_shapes
|
|
487
|
+
), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
|
|
488
|
+
assert dim >= 0, "Dimension must be greater than or equal to 0"
|
|
489
|
+
new_idx = dataset_idx - self.grid_count(dim)
|
|
490
|
+
if new_idx < 0:
|
|
491
|
+
return None
|