ngio 0.2.0a2__py3-none-any.whl → 0.5.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. ngio/__init__.py +40 -12
  2. ngio/common/__init__.py +16 -32
  3. ngio/common/_dimensions.py +270 -48
  4. ngio/common/_masking_roi.py +153 -0
  5. ngio/common/_pyramid.py +267 -73
  6. ngio/common/_roi.py +290 -66
  7. ngio/common/_synt_images_utils.py +101 -0
  8. ngio/common/_zoom.py +54 -22
  9. ngio/experimental/__init__.py +5 -0
  10. ngio/experimental/iterators/__init__.py +15 -0
  11. ngio/experimental/iterators/_abstract_iterator.py +390 -0
  12. ngio/experimental/iterators/_feature.py +189 -0
  13. ngio/experimental/iterators/_image_processing.py +130 -0
  14. ngio/experimental/iterators/_mappers.py +48 -0
  15. ngio/experimental/iterators/_rois_utils.py +126 -0
  16. ngio/experimental/iterators/_segmentation.py +235 -0
  17. ngio/hcs/__init__.py +17 -58
  18. ngio/hcs/_plate.py +1354 -0
  19. ngio/images/__init__.py +30 -9
  20. ngio/images/_abstract_image.py +968 -0
  21. ngio/images/_create_synt_container.py +132 -0
  22. ngio/images/_create_utils.py +423 -0
  23. ngio/images/_image.py +926 -0
  24. ngio/images/_label.py +417 -0
  25. ngio/images/_masked_image.py +531 -0
  26. ngio/images/_ome_zarr_container.py +1235 -0
  27. ngio/images/_table_ops.py +471 -0
  28. ngio/io_pipes/__init__.py +75 -0
  29. ngio/io_pipes/_io_pipes.py +361 -0
  30. ngio/io_pipes/_io_pipes_masked.py +488 -0
  31. ngio/io_pipes/_io_pipes_roi.py +146 -0
  32. ngio/io_pipes/_io_pipes_types.py +56 -0
  33. ngio/io_pipes/_match_shape.py +377 -0
  34. ngio/io_pipes/_ops_axes.py +344 -0
  35. ngio/io_pipes/_ops_slices.py +411 -0
  36. ngio/io_pipes/_ops_slices_utils.py +199 -0
  37. ngio/io_pipes/_ops_transforms.py +104 -0
  38. ngio/io_pipes/_zoom_transform.py +180 -0
  39. ngio/ome_zarr_meta/__init__.py +39 -15
  40. ngio/ome_zarr_meta/_meta_handlers.py +490 -96
  41. ngio/ome_zarr_meta/ngio_specs/__init__.py +24 -10
  42. ngio/ome_zarr_meta/ngio_specs/_axes.py +268 -234
  43. ngio/ome_zarr_meta/ngio_specs/_channels.py +125 -41
  44. ngio/ome_zarr_meta/ngio_specs/_dataset.py +42 -87
  45. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +536 -2
  46. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +202 -198
  47. ngio/ome_zarr_meta/ngio_specs/_pixel_size.py +72 -34
  48. ngio/ome_zarr_meta/v04/__init__.py +21 -5
  49. ngio/ome_zarr_meta/v04/_custom_models.py +18 -0
  50. ngio/ome_zarr_meta/v04/{_v04_spec_utils.py → _v04_spec.py} +151 -90
  51. ngio/ome_zarr_meta/v05/__init__.py +27 -0
  52. ngio/ome_zarr_meta/v05/_custom_models.py +18 -0
  53. ngio/ome_zarr_meta/v05/_v05_spec.py +511 -0
  54. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  55. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  56. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  57. ngio/resources/__init__.py +55 -0
  58. ngio/resources/resource_model.py +36 -0
  59. ngio/tables/__init__.py +20 -4
  60. ngio/tables/_abstract_table.py +270 -0
  61. ngio/tables/_tables_container.py +449 -0
  62. ngio/tables/backends/__init__.py +50 -1
  63. ngio/tables/backends/_abstract_backend.py +200 -31
  64. ngio/tables/backends/_anndata.py +139 -0
  65. ngio/tables/backends/_anndata_utils.py +10 -114
  66. ngio/tables/backends/_csv.py +19 -0
  67. ngio/tables/backends/_json.py +92 -0
  68. ngio/tables/backends/_parquet.py +19 -0
  69. ngio/tables/backends/_py_arrow_backends.py +222 -0
  70. ngio/tables/backends/_table_backends.py +162 -38
  71. ngio/tables/backends/_utils.py +608 -0
  72. ngio/tables/v1/__init__.py +19 -4
  73. ngio/tables/v1/_condition_table.py +71 -0
  74. ngio/tables/v1/_feature_table.py +79 -115
  75. ngio/tables/v1/_generic_table.py +21 -90
  76. ngio/tables/v1/_roi_table.py +486 -137
  77. ngio/transforms/__init__.py +5 -0
  78. ngio/transforms/_zoom.py +19 -0
  79. ngio/utils/__init__.py +16 -14
  80. ngio/utils/_cache.py +48 -0
  81. ngio/utils/_datasets.py +121 -13
  82. ngio/utils/_fractal_fsspec_store.py +42 -0
  83. ngio/utils/_zarr_utils.py +374 -218
  84. ngio-0.5.0b4.dist-info/METADATA +147 -0
  85. ngio-0.5.0b4.dist-info/RECORD +88 -0
  86. {ngio-0.2.0a2.dist-info → ngio-0.5.0b4.dist-info}/WHEEL +1 -1
  87. ngio/common/_array_pipe.py +0 -160
  88. ngio/common/_axes_transforms.py +0 -63
  89. ngio/common/_common_types.py +0 -5
  90. ngio/common/_slicer.py +0 -97
  91. ngio/images/abstract_image.py +0 -240
  92. ngio/images/create.py +0 -251
  93. ngio/images/image.py +0 -389
  94. ngio/images/label.py +0 -236
  95. ngio/images/omezarr_container.py +0 -535
  96. ngio/ome_zarr_meta/_generic_handlers.py +0 -320
  97. ngio/ome_zarr_meta/v04/_meta_handlers.py +0 -54
  98. ngio/tables/_validators.py +0 -192
  99. ngio/tables/backends/_anndata_v1.py +0 -75
  100. ngio/tables/backends/_json_v1.py +0 -56
  101. ngio/tables/tables_container.py +0 -300
  102. ngio/tables/v1/_masking_roi_table.py +0 -175
  103. ngio/utils/_logger.py +0 -29
  104. ngio-0.2.0a2.dist-info/METADATA +0 -95
  105. ngio-0.2.0a2.dist-info/RECORD +0 -53
  106. {ngio-0.2.0a2.dist-info → ngio-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,411 @@
1
+ import math
2
+ from collections.abc import Mapping, Sequence
3
+ from typing import TypeAlias, assert_never
4
+ from warnings import warn
5
+
6
+ import dask.array as da
7
+ import numpy as np
8
+ import zarr
9
+ from pydantic import BaseModel, ConfigDict
10
+
11
+ from ngio.common._dimensions import Dimensions
12
+ from ngio.io_pipes._ops_slices_utils import compute_slice_chunks
13
+ from ngio.ome_zarr_meta.ngio_specs import Axis
14
+ from ngio.utils import NgioValueError
15
+
16
# User-facing selection for a single axis: a slice, a sequence of indices,
# a single index, or None (None is normalized to "select the whole axis").
SlicingInputType: TypeAlias = slice | Sequence[int] | int | None
# Normalized per-axis selection stored in SlicingOps: None becomes
# slice(None) and sequences become either a slice or a list[int].
SlicingType: TypeAlias = slice | list[int] | int
18
+
19
+ ##############################################################
20
+ #
21
+ # "SlicingOps" model
22
+ #
23
+ ##############################################################
24
+
25
+
26
+ def _int_boundary_check(value: int, shape: int) -> int:
27
+ """Ensure that the integer value is within the boundaries of the array shape."""
28
+ if value < 0 or value >= shape:
29
+ raise NgioValueError(
30
+ f"Invalid index {value}. Index is out of bounds for axis with size {shape}."
31
+ )
32
+ return value
33
+
34
+
35
+ def _slicing_tuple_boundary_check(
36
+ slicing_tuple: tuple[SlicingType, ...],
37
+ array_shape: tuple[int, ...],
38
+ ) -> tuple[SlicingType, ...]:
39
+ """Ensure that the slicing tuple is within the boundaries of the array shape.
40
+
41
+ This function normalizes the slicing tuple to ensure that the selection
42
+ is within the boundaries of the array shape.
43
+ """
44
+ if len(slicing_tuple) != len(array_shape):
45
+ raise NgioValueError(
46
+ f"Invalid slicing tuple {slicing_tuple}. "
47
+ f"Length {len(slicing_tuple)} does not match array shape {array_shape}."
48
+ )
49
+ out_slicing_tuple = []
50
+ for sl, sh in zip(slicing_tuple, array_shape, strict=True):
51
+ if isinstance(sl, slice):
52
+ start, stop, step = sl.start, sl.stop, sl.step
53
+ if start is not None:
54
+ start = math.floor(start)
55
+ start = max(0, min(start, sh))
56
+ if stop is not None:
57
+ stop = math.ceil(stop)
58
+ stop = max(0, min(stop, sh))
59
+ out_slicing_tuple.append(slice(start, stop, step))
60
+ elif isinstance(sl, int):
61
+ _int_boundary_check(sl, shape=sh)
62
+ out_slicing_tuple.append(sl)
63
+ elif isinstance(sl, list):
64
+ [_int_boundary_check(i, shape=sh) for i in sl]
65
+ out_slicing_tuple.append(sl)
66
+ else:
67
+ assert_never(sl)
68
+
69
+ return tuple(out_slicing_tuple)
70
+
71
+
72
class SlicingOps(BaseModel):
    """Per-axis selection on an on-disk array.

    Stores the on-disk axis names, shape and chunk shape together with one
    selection (``slice``, ``int`` or ``list[int]``) per on-disk axis.
    """

    # Axis names, shape and chunk shape of the array as stored on disk.
    on_disk_axes: tuple[str, ...]
    on_disk_shape: tuple[int, ...]
    on_disk_chunks: tuple[int, ...]
    # One selection per on-disk axis, in the order of ``on_disk_axes``.
    slicing_tuple: tuple[SlicingType, ...]
    # NOTE(review): arbitrary_types_allowed is presumably required because
    # ``slice`` has no built-in pydantic schema — confirm.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    @property
    def normalized_slicing_tuple(self) -> tuple[SlicingType, ...]:
        """Slicing tuple with slice bounds clamped to ``on_disk_shape``.

        Recomputed on every access via ``_slicing_tuple_boundary_check``,
        which raises ``NgioValueError`` on a length mismatch or an
        out-of-bounds int/list index.
        """
        return _slicing_tuple_boundary_check(
            slicing_tuple=self.slicing_tuple,
            array_shape=self.on_disk_shape,
        )

    @property
    def slice_axes(self) -> tuple[str, ...]:
        """Names of the axes kept in the result after slicing.

        Axes selected with a single ``int`` are dropped; axes selected with a
        slice or a list are kept, in on-disk order.
        """
        return tuple(
            ax
            for ax, sl in zip(self.on_disk_axes, self.slicing_tuple, strict=True)
            if not isinstance(sl, int)
        )

    def slice_chunks(self) -> set[tuple[int, ...]]:
        """The chunk grid coordinates required to read or write the slice."""
        return compute_slice_chunks(
            shape=self.on_disk_shape,
            chunks=self.on_disk_chunks,
            slicing_tuple=self.normalized_slicing_tuple,
        )

    def get(self, ax_name: str, normalize: bool = False) -> SlicingType:
        """Return the selection for a single axis.

        Args:
            ax_name: Name of an on-disk axis.
            normalize: If True, read from ``normalized_slicing_tuple``
                instead of the raw ``slicing_tuple``.

        Returns:
            The selection for ``ax_name``, or ``slice(None)`` if ``ax_name``
            is not one of the on-disk axes.
        """
        # The normalized tuple is computed before the membership check, so
        # validation errors are raised even for axes that are not on disk.
        slicing_tuple = (
            self.normalized_slicing_tuple if normalize else self.slicing_tuple
        )
        if ax_name not in self.on_disk_axes:
            return slice(None)
        return slicing_tuple[self.on_disk_axes.index(ax_name)]
116
+
117
+
118
+ def _check_list_in_slicing_tuple(
119
+ slicing_tuple: tuple[SlicingType, ...],
120
+ ) -> tuple[None, None] | tuple[int, list[int]]:
121
+ """Check if there are any lists in the slicing tuple.
122
+
123
+ Dask regions when setting data do not support non-contiguous
124
+ selections natively.
125
+ Ngio support a single list in the slicing tuple to allow non-contiguous
126
+ selection (main use case: selecting multiple channels).
127
+ """
128
+ # Find if the is any list in the slicing tuple
129
+ # If there is one we need to handle it differently
130
+ list_in_slice = [(i, s) for i, s in enumerate(slicing_tuple) if isinstance(s, list)]
131
+ if not list_in_slice:
132
+ # No list in the slicing tuple
133
+ return None, None
134
+
135
+ if len(list_in_slice) > 1:
136
+ raise NotImplementedError(
137
+ "Slicing with multiple non-contiguous tuples/lists "
138
+ "is not supported yet in Ngio. Use directly the "
139
+ "zarr.Array api to get the correct array slice."
140
+ )
141
+ # Complex case, we have exactly one tuple in the slicing tuple
142
+ ax, first_tuple = list_in_slice[0]
143
+ if len(first_tuple) > 100:
144
+ warn(
145
+ "Performance warning: "
146
+ "Non-contiguous slicing with a tuple/list with more than 100 elements is "
147
+ "not natively supported by zarr. This is implemented by Ngio by performing "
148
+ "multiple reads and stacking the result.",
149
+ stacklevel=2,
150
+ )
151
+ return ax, first_tuple
152
+
153
+
154
+ ##############################################################
155
+ #
156
+ # Slicing implementations
157
+ #
158
+ ##############################################################
159
+
160
+
161
def get_slice_as_numpy(zarr_array: zarr.Array, slicing_ops: SlicingOps) -> np.ndarray:
    """Read the selection described by ``slicing_ops`` into memory.

    The normalized slicing tuple (slices clamped to the on-disk shape) is
    passed directly to ``zarr_array[...]``. A list selection in the tuple
    gets no special handling here; it is left to the array's own indexing.

    Args:
        zarr_array: Array to read from.
        slicing_ops: Selection to apply.

    Returns:
        The selected data as returned by ``zarr_array.__getitem__``.
    """
    return zarr_array[slicing_ops.normalized_slicing_tuple]
167
+
168
+
169
def get_slice_as_dask(zarr_array: zarr.Array, slicing_ops: SlicingOps) -> da.Array:
    """Return a dask array for the selection described by ``slicing_ops``.

    The full zarr array is wrapped with ``da.from_zarr`` first and then
    indexed with the normalized slicing tuple.
    """
    lazy_array = da.from_zarr(zarr_array)
    selection = slicing_ops.normalized_slicing_tuple
    return lazy_array[selection]
174
+
175
+
176
def set_slice_as_numpy(
    zarr_array: zarr.Array,
    patch: np.ndarray,
    slicing_ops: SlicingOps,
) -> None:
    """Write ``patch`` into ``zarr_array`` at the selection in ``slicing_ops``.

    Slice bounds are clamped to the on-disk shape before assignment. The
    patch is assigned with ``zarr_array[...] = patch``, so it must be
    compatible with the shape of that selection.
    """
    target_region = slicing_ops.normalized_slicing_tuple
    zarr_array[target_region] = patch
183
+
184
+
185
def handle_int_set_as_dask(
    patch: da.Array,
    slicing_tuple: tuple[SlicingType, ...],
) -> tuple[da.Array, tuple[SlicingType, ...]]:
    """Turn integer selections into length-1 slices for region writes.

    For every axis selected with an int ``k``, the patch gains a singleton
    dimension at that axis position and the selection becomes
    ``slice(k, k + 1)``. Patch and region then have the same rank.

    Returns:
        The (possibly expanded) patch and the rewritten slicing tuple.
    """
    region: list[SlicingType] = []
    for axis_idx, selection in enumerate(slicing_tuple):
        if not isinstance(selection, int):
            region.append(selection)
            continue
        # Axes are visited left to right, so inserting at axis_idx places
        # the new singleton dimension at its final position.
        patch = da.expand_dims(patch, axis=axis_idx)
        region.append(slice(selection, selection + 1))
    return patch, tuple(region)
199
+
200
+
201
def set_slice_as_dask(
    zarr_array: zarr.Array, patch: da.Array, slicing_ops: SlicingOps
) -> None:
    """Write a dask ``patch`` into ``zarr_array`` at the selection in ``slicing_ops``.

    Integer selections are first turned into length-1 slices, with matching
    singleton axes added to ``patch``, so the selection can be used as a
    ``da.to_zarr`` region. At most one axis may be selected with a list; in
    that case one length-1 region is written per listed index.

    Args:
        zarr_array: Destination array.
        patch: Data to write. Presumably shaped like the selection, with
            int-selected axes absent — confirm with callers.
        slicing_ops: Target selection; slices are clamped to the on-disk shape.

    Raises:
        NotImplementedError: If more than one axis is selected with a list.
        NgioValueError: If an int or list index is out of bounds.
    """
    slice_tuple = slicing_ops.normalized_slicing_tuple
    ax, list_selection = _check_list_in_slicing_tuple(slice_tuple)
    # Ints are replaced in place by length-1 slices, so `ax` stays valid.
    patch, slice_tuple = handle_int_set_as_dask(patch, slice_tuple)
    if ax is None:
        # No list selection: the whole patch is written as a single region.
        da.to_zarr(arr=patch, url=zarr_array, region=slice_tuple)
        return

    # `_check_list_in_slicing_tuple` always returns the list together with `ax`.
    assert list_selection is not None
    for i, idx in enumerate(list_selection):
        # Element i of the patch along `ax` is written to on-disk index `idx`.
        sub_region = (*slice_tuple[:ax], slice(idx, idx + 1), *slice_tuple[ax + 1 :])
        sub_patch = da.take(patch, indices=i, axis=ax)
        sub_patch = da.expand_dims(sub_patch, axis=ax)
        da.to_zarr(arr=sub_patch, url=zarr_array, region=sub_region)
220
+
221
+
222
+ ##############################################################
223
+ #
224
+ # Builder functions
225
+ #
226
+ ##############################################################
227
+
228
+
229
+ def _try_to_slice(value: Sequence[int]) -> slice | list[int]:
230
+ """Try to convert a list of integers into a slice if they are contiguous.
231
+
232
+ - If the input is empty, return an empty tuple.
233
+ - If the input is sorted, and contains contiguous integers,
234
+ return a slice from the minimum to the maximum integer.
235
+ - Otherwise, return the input as a list of integers.
236
+
237
+ This is useful for optimizing array slicing operations
238
+ by allowing the use of slices when possible, which can be more efficient.
239
+ """
240
+ if not value:
241
+ raise NgioValueError("Ngio does not support empty sequences as slice input.")
242
+
243
+ if not all(isinstance(i, int) for i in value):
244
+ _value = []
245
+ for i in value:
246
+ try:
247
+ _value.append(int(i))
248
+ except Exception as e:
249
+ raise NgioValueError(
250
+ f"Invalid value {i} of type {type(i)} in sequence {value}"
251
+ ) from e
252
+ value = _value
253
+ # If the input is not sorted, return it as a tuple
254
+ max_input = max(value)
255
+ min_input = min(value)
256
+ assert min_input >= 0, "Input must contain non-negative integers"
257
+
258
+ if sorted(value) == list(range(min_input, max_input + 1)):
259
+ return slice(min_input, max_input + 1)
260
+
261
+ return list(value)
262
+
263
+
264
+ def _remove_channel_slicing(
265
+ slicing_dict: dict[str, SlicingInputType],
266
+ dimensions: Dimensions,
267
+ ) -> dict[str, SlicingInputType]:
268
+ """This utility function removes the channel selection from the slice kwargs.
269
+
270
+ if ignore_channel_selection is True, it will remove the channel selection
271
+ regardless of the dimensions. If the ignore_channel_selection is False
272
+ it will fail.
273
+ """
274
+ if dimensions.is_multi_channels:
275
+ return slicing_dict
276
+
277
+ if "c" in slicing_dict:
278
+ slicing_dict.pop("c", None)
279
+ return slicing_dict
280
+
281
+
282
+ def _check_slicing_virtual_axes(slice_: SlicingInputType) -> bool:
283
+ """Check if the slice_ is compatible with virtual axes.
284
+
285
+ Virtual axes are axes that are not present in the actual data,
286
+ such as time or channel axes in some datasets.
287
+ So the only valid slices for virtual axes are:
288
+ - None: means all data along the axis
289
+ - 0: means the first element along the axis
290
+ - slice([0, None], [1, None])
291
+ """
292
+ if slice_ is None or slice_ == 0:
293
+ return True
294
+ if isinstance(slice_, slice):
295
+ if slice_.start is None and slice_.stop is None:
296
+ return True
297
+ if slice_.start == 0 and slice_.stop is None:
298
+ return True
299
+ if slice_.start is None and slice_.stop == 0:
300
+ return True
301
+ if slice_.start == 0 and slice_.stop == 1:
302
+ return True
303
+ if isinstance(slice_, Sequence):
304
+ if len(slice_) == 1 and slice_[0] == 0:
305
+ return True
306
+ return False
307
+
308
+
309
+ def _clean_slicing_dict(
310
+ dimensions: Dimensions,
311
+ slicing_dict: Mapping[str, SlicingInputType],
312
+ remove_channel_selection: bool = False,
313
+ ) -> dict[str, SlicingInputType]:
314
+ """Clean the slicing dict.
315
+
316
+ This function will:
317
+ - Validate that the axes in the slicing_dict are present in the dimensions.
318
+ - Make sure that the slicing_dict uses the on-disk axis names.
319
+ - Check for duplicate axis names in the slicing_dict.
320
+ - Clean up channel selection if the dimensions
321
+ """
322
+ clean_slicing_dict: dict[str, SlicingInputType] = {}
323
+ for axis_name, slice_ in slicing_dict.items():
324
+ axis = dimensions.axes_handler.get_axis(axis_name)
325
+ if axis is None:
326
+ # Virtual axes should be allowed to be selected
327
+ # Common use case is still allowing channel_selection
328
+ # When the zarr has not channel axis.
329
+ if not _check_slicing_virtual_axes(slice_):
330
+ raise NgioValueError(
331
+ f"Invalid axis selection:{axis_name}={slice_}. "
332
+ f"Not found on the on-disk axes {dimensions.axes}."
333
+ )
334
+ # Virtual axes can be safely ignored
335
+ continue
336
+ if axis.name in clean_slicing_dict:
337
+ raise NgioValueError(
338
+ f"Duplicate axis {axis.name} in slice kwargs. "
339
+ "Please provide unique axis names."
340
+ )
341
+ clean_slicing_dict[axis.name] = slice_
342
+
343
+ if remove_channel_selection:
344
+ clean_slicing_dict = _remove_channel_slicing(
345
+ slicing_dict=clean_slicing_dict, dimensions=dimensions
346
+ )
347
+ return clean_slicing_dict
348
+
349
+
350
+ def _normalize_slicing_tuple(
351
+ axis: Axis,
352
+ slicing_dict: dict[str, SlicingInputType],
353
+ ) -> SlicingType:
354
+ """Normalize the slicing dict to tuple.
355
+
356
+ Since the slicing dict can contain different types of values
357
+ We need to normalize them to more predictable types.
358
+ The output types are:
359
+ - slice
360
+ - int
361
+ - list of int (for non-contiguous selection)
362
+ """
363
+ axis_name = axis.name
364
+ if axis_name not in slicing_dict:
365
+ # If no slice is provided for the axis, use a full slice
366
+ return slice(None)
367
+
368
+ value = slicing_dict[axis_name]
369
+ if value is None:
370
+ return slice(None)
371
+ if isinstance(value, slice) or isinstance(value, int):
372
+ return value
373
+ elif isinstance(value, Sequence):
374
+ # If a contiguous sequence of integers is provided,
375
+ # convert it to a slice for simplicity.
376
+ # Alternatively, it will be converted to a list of ints
377
+ return _try_to_slice(value)
378
+
379
+ raise NgioValueError(
380
+ f"Invalid slice definition {value} of type {type(value)}. "
381
+ "Allowed types are: int, slice, sequence of int or None."
382
+ )
383
+
384
+
385
def build_slicing_ops(
    *,
    dimensions: Dimensions,
    slicing_dict: dict[str, SlicingInputType] | None,
    remove_channel_selection: bool = False,
) -> SlicingOps:
    """Build the ``SlicingOps`` used to query an array.

    ``slicing_dict`` (``None`` counts as empty) is cleaned against
    ``dimensions``. It is then expanded into one normalized selection per
    on-disk axis, in the order of ``dimensions.axes_handler.axes``. Axes
    without an entry get ``slice(None)``.
    """
    cleaned = _clean_slicing_dict(
        dimensions=dimensions,
        slicing_dict=slicing_dict or {},
        remove_channel_selection=remove_channel_selection,
    )
    per_axis = [
        _normalize_slicing_tuple(axis=axis, slicing_dict=cleaned)
        for axis in dimensions.axes_handler.axes
    ]
    return SlicingOps(
        on_disk_axes=dimensions.axes_handler.axes_names,
        on_disk_shape=dimensions.shape,
        on_disk_chunks=dimensions.chunks,
        slicing_tuple=tuple(per_axis),
    )
@@ -0,0 +1,199 @@
1
+ import warnings
2
+ from collections.abc import Iterable, Iterator
3
+ from itertools import product
4
+ from typing import TypeAlias, TypeVar
5
+
6
+ from ngio.utils import NgioValueError
7
+
8
# Generic item type for _pairs_stream.
T = TypeVar("T")
9
+
10
+ ##############################################################
11
+ #
12
+ # Check slice overlaps
13
+ #
14
+ ##############################################################
15
+
16
+
17
+ def _pairs_stream(iterable: Iterable[T]) -> Iterator[tuple[T, T]]:
18
+ # Same as combinations but yields pairs as soon as they are generated
19
+ seen: list[T] = []
20
+ for a in iterable:
21
+ for b in seen:
22
+ yield b, a
23
+ seen.append(a)
24
+
25
+
26
# Per-axis selection: a slice, a list of indices, or a single index.
SlicingType: TypeAlias = slice | list[int] | int
27
+
28
+
29
def check_elem_intersection(s1: SlicingType, s2: SlicingType) -> bool:
    """Check whether two single-axis selections share at least one index.

    - Slices are treated as half-open intervals ``[start, stop)``. A missing
      start means 0 and a missing stop means unbounded. An empty slice
      (``stop <= start``) never intersects anything.
    - Integers intersect if they are equal.
    - Lists intersect if they have any element in common.

    Raises:
        NgioValueError: If ``s1`` and ``s2`` are not of the same type.
        NotImplementedError: If either slice has a step other than 1.
        TypeError: For unsupported selection types.
    """
    if not isinstance(s1, type(s2)):
        raise NgioValueError(
            f"Slices must be of the same type. Got {type(s1)} and {type(s2)}"
        )

    if isinstance(s1, slice) and isinstance(s2, slice):
        if (s1.step or 1) != 1 or (s2.step or 1) != 1:
            raise NotImplementedError(
                "Intersection for slices with step != 1 is not implemented"
            )
        # Compare with None explicitly: `s.stop or inf` would turn an
        # explicit stop of 0 (an empty slice) into an unbounded one.
        start1 = 0 if s1.start is None else s1.start
        stop1 = float("inf") if s1.stop is None else s1.stop
        start2 = 0 if s2.start is None else s2.start
        stop2 = float("inf") if s2.stop is None else s2.stop
        if stop1 <= start1 or stop2 <= start2:
            # An empty selection cannot overlap anything.
            return False
        return start1 < stop2 and start2 < stop1
    elif isinstance(s1, int) and isinstance(s2, int):
        return s1 == s2
    elif isinstance(s1, list) and isinstance(s2, list):
        return bool(set(s1) & set(s2))
    else:
        raise TypeError("Unsupported slice type")
66
+
67
+
68
def check_slicing_tuple_intersection(
    s1: tuple[SlicingType, ...], s2: tuple[SlicingType, ...]
) -> bool:
    """Return True if two multi-axis selections overlap.

    Two selections overlap only when they intersect on every axis, as
    decided per axis by ``check_elem_intersection``. The comparison stops at
    the first axis without an intersection.

    Raises:
        NgioValueError: If the tuples have different lengths.
    """
    if len(s1) != len(s2):
        raise NgioValueError("Slices must have the same length")
    for elem1, elem2 in zip(s1, s2, strict=True):
        if not check_elem_intersection(elem1, elem2):
            return False
    return True
75
+
76
+
77
def check_if_regions_overlap(slices: Iterable[tuple[SlicingType, ...]]) -> bool:
    """Check whether any two slicing tuples in ``slices`` overlap.

    Every pair is compared with ``check_slicing_tuple_intersection``. This
    is a brute-force check over O(n^2) pairs, stopping at the first overlap
    found. A performance warning is emitted once, after roughly 10 000 pair
    comparisons.

    Returns:
        True if any overlap is found, False otherwise.
    """
    for n_checked, (si, sj) in enumerate(_pairs_stream(slices)):
        if check_slicing_tuple_intersection(si, sj):
            return True

        if n_checked == 10_000:
            warnings.warn(
                "Performance Warning check_if_regions_overlap is O(n^2) and may be "
                "slow for large numbers of regions.",
                stacklevel=2,
            )
    return False
95
+
96
+
97
+ ##############################################################
98
+ #
99
+ # Check chunk overlaps
100
+ #
101
+ ##############################################################
102
+
103
+
104
+ def _normalize_slice(slc: slice, size: int) -> tuple[int, int]:
105
+ if slc.step not in (None, 1):
106
+ raise NgioValueError(f"Only step=1 slices supported, got step={slc.step}")
107
+ start = 0 if slc.start is None else slc.start
108
+ stop = size if slc.stop is None else slc.stop
109
+ if start < 0 or stop < 0:
110
+ raise NgioValueError("Negative slice bounds are not supported")
111
+ # clamp to [0, size]
112
+ start = min(start, size)
113
+ stop = min(stop, size)
114
+ if start > stop:
115
+ # empty selection
116
+ return (0, 0)
117
+ return start, stop
118
+
119
+
120
+ def _chunk_indices_for_axis(sel: SlicingType, size: int, csize: int) -> list[int]:
121
+ """From a selection for a single axis, return the list chunk indices touched."""
122
+ if isinstance(sel, slice):
123
+ start, stop = _normalize_slice(sel, size)
124
+ if start >= stop: # empty
125
+ return []
126
+ first = start // csize
127
+ last = (stop - 1) // csize
128
+ return list(range(first, last + 1))
129
+
130
+ if isinstance(sel, int):
131
+ if sel < 0 or sel >= size:
132
+ raise IndexError(f"index {sel} out of bounds for axis of size {size}")
133
+ return [sel // csize]
134
+
135
+ if isinstance(sel, list):
136
+ if not sel:
137
+ return []
138
+ chunks_hit = {}
139
+ for v in sel:
140
+ if not isinstance(v, int):
141
+ raise TypeError("Only integers allowed inside tuple selections")
142
+ if v < 0 or v >= size:
143
+ raise IndexError(f"index {v} out of bounds for axis of size {size}")
144
+ chunks_hit[v // csize] = None
145
+ return sorted(chunks_hit.keys())
146
+
147
+ raise TypeError(f"Unsupported index type: {type(sel)!r}")
148
+
149
+
150
def compute_slice_chunks(
    shape: tuple[int, ...],
    chunks: tuple[int, ...],
    slicing_tuple: tuple[SlicingType, ...],
) -> set[tuple[int, ...]]:
    """Compute the set of chunk coordinates touched by ``slicing_tuple``.

    Args:
        shape: Overall array shape ``(s1, s2, ...)``.
        chunks: Chunk shape ``(c1, c2, ...)``.
        slicing_tuple: One selection per axis: a slice, an int, or a list of
            ints.

    Returns:
        The chunk-grid coordinates (one int per axis) of every chunk the
        selection touches. The set is empty if the selection is empty on any
        axis.

    Raises:
        NgioValueError: If ``slicing_tuple`` and ``shape`` differ in length.
    """
    if len(slicing_tuple) != len(shape):
        raise NgioValueError(
            f"key must have {len(shape)} items, got {len(slicing_tuple)}"
        )

    per_axis_chunks = [
        _chunk_indices_for_axis(sel, size, csize)
        for sel, size, csize in zip(slicing_tuple, shape, chunks, strict=True)
    ]
    # The cartesian product is empty as soon as one axis contributes no chunk.
    return set(product(*per_axis_chunks))
177
+
178
+
179
def check_if_chunks_overlap(
    slices: Iterable[tuple[SlicingType, ...]],
    shape: tuple[int, ...],
    chunks: tuple[int, ...],
) -> bool:
    """Check whether any two slicing tuples touch a common chunk.

    Each selection is converted to its set of chunk coordinates with
    ``compute_slice_chunks``. Two selections share a chunk exactly when the
    later one's chunks intersect the union of all earlier selections'
    chunks. A single running set therefore replaces the pairwise
    comparison: the cost grows with the total number of touched chunks
    rather than with the square of the number of selections.

    Returns:
        True as soon as a shared chunk is found, False otherwise.
    """
    seen_chunks: set[tuple[int, ...]] = set()
    for si in slices:
        si_chunks = compute_slice_chunks(shape, chunks, si)
        if not seen_chunks.isdisjoint(si_chunks):
            return True
        seen_chunks |= si_chunks
    return False