ngio 0.4.0a2__py3-none-any.whl → 0.4.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. ngio/__init__.py +1 -2
  2. ngio/common/__init__.py +2 -51
  3. ngio/common/_dimensions.py +223 -64
  4. ngio/common/_pyramid.py +42 -23
  5. ngio/common/_roi.py +94 -418
  6. ngio/common/_zoom.py +32 -7
  7. ngio/experimental/iterators/_abstract_iterator.py +2 -2
  8. ngio/experimental/iterators/_feature.py +10 -15
  9. ngio/experimental/iterators/_image_processing.py +18 -28
  10. ngio/experimental/iterators/_rois_utils.py +6 -6
  11. ngio/experimental/iterators/_segmentation.py +38 -54
  12. ngio/images/_abstract_image.py +136 -94
  13. ngio/images/_create.py +16 -0
  14. ngio/images/_create_synt_container.py +10 -0
  15. ngio/images/_image.py +33 -9
  16. ngio/images/_label.py +24 -3
  17. ngio/images/_masked_image.py +60 -81
  18. ngio/images/_ome_zarr_container.py +34 -1
  19. ngio/io_pipes/__init__.py +49 -0
  20. ngio/io_pipes/_io_pipes.py +286 -0
  21. ngio/io_pipes/_io_pipes_masked.py +481 -0
  22. ngio/io_pipes/_io_pipes_roi.py +143 -0
  23. ngio/io_pipes/_io_pipes_utils.py +299 -0
  24. ngio/io_pipes/_match_shape.py +376 -0
  25. ngio/io_pipes/_ops_axes.py +146 -0
  26. ngio/io_pipes/_ops_slices.py +218 -0
  27. ngio/io_pipes/_ops_transforms.py +104 -0
  28. ngio/io_pipes/_zoom_transform.py +175 -0
  29. ngio/ome_zarr_meta/__init__.py +6 -2
  30. ngio/ome_zarr_meta/ngio_specs/__init__.py +6 -4
  31. ngio/ome_zarr_meta/ngio_specs/_axes.py +182 -70
  32. ngio/ome_zarr_meta/ngio_specs/_dataset.py +47 -121
  33. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +30 -22
  34. ngio/ome_zarr_meta/ngio_specs/_pixel_size.py +17 -1
  35. ngio/ome_zarr_meta/v04/_v04_spec_utils.py +33 -30
  36. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  37. ngio/resources/__init__.py +1 -0
  38. ngio/resources/resource_model.py +1 -0
  39. ngio/tables/v1/_roi_table.py +11 -3
  40. ngio/{common/transforms → transforms}/__init__.py +1 -1
  41. ngio/transforms/_zoom.py +19 -0
  42. ngio/utils/_zarr_utils.py +5 -1
  43. {ngio-0.4.0a2.dist-info → ngio-0.4.0a4.dist-info}/METADATA +1 -1
  44. ngio-0.4.0a4.dist-info/RECORD +83 -0
  45. ngio/common/_array_io_pipes.py +0 -554
  46. ngio/common/_array_io_utils.py +0 -508
  47. ngio/common/transforms/_label.py +0 -12
  48. ngio/common/transforms/_zoom.py +0 -109
  49. ngio-0.4.0a2.dist-info/RECORD +0 -76
  50. {ngio-0.4.0a2.dist-info → ngio-0.4.0a4.dist-info}/WHEEL +0 -0
  51. {ngio-0.4.0a2.dist-info → ngio-0.4.0a4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,299 @@
1
+ from collections.abc import Mapping, Sequence
2
+ from typing import TypeAlias
3
+
4
+ from ngio.common._dimensions import Dimensions
5
+ from ngio.io_pipes._ops_slices import SlicingOps, SlicingType
6
+ from ngio.ome_zarr_meta.ngio_specs import Axis
7
+ from ngio.ome_zarr_meta.ngio_specs._axes import AxesOps
8
+ from ngio.utils import NgioValueError
9
+
10
+ SlicingInputType: TypeAlias = slice | Sequence[int] | int | None
11
+
12
+
13
+ def _try_to_slice(value: Sequence[int]) -> slice | tuple[int, ...]:
14
+ """Try to convert a list of integers into a slice if they are contiguous.
15
+
16
+ - If the input is empty, return an empty tuple.
17
+ - If the input is sorted, and contains contiguous integers,
18
+ return a slice from the minimum to the maximum integer.
19
+ - Otherwise, return the input as a tuple.
20
+
21
+ This is useful for optimizing array slicing operations
22
+ by allowing the use of slices when possible, which can be more efficient.
23
+ """
24
+ if not value:
25
+ raise NgioValueError("Ngio does not support empty sequences as slice input.")
26
+
27
+ if not all(isinstance(i, int) for i in value):
28
+ _value = []
29
+ for i in value:
30
+ try:
31
+ _value.append(int(i))
32
+ except Exception as e:
33
+ raise NgioValueError(
34
+ f"Invalid value {i} of type {type(i)} in sequence {value}"
35
+ ) from e
36
+ value = _value
37
+ # If the input is not sorted, return it as a tuple
38
+ max_input = max(value)
39
+ min_input = min(value)
40
+ assert min_input >= 0, "Input must contain non-negative integers"
41
+
42
+ if sorted(value) == list(range(min_input, max_input + 1)):
43
+ return slice(min_input, max_input + 1)
44
+
45
+ return tuple(value)
46
+
47
+
48
+ def _remove_channel_slicing(
49
+ slicing_dict: dict[str, SlicingInputType],
50
+ dimensions: Dimensions,
51
+ ) -> dict[str, SlicingInputType]:
52
+ """This utility function removes the channel selection from the slice kwargs.
53
+
54
+ if ignore_channel_selection is True, it will remove the channel selection
55
+ regardless of the dimensions. If the ignore_channel_selection is False
56
+ it will fail.
57
+ """
58
+ if dimensions.is_multi_channels:
59
+ return slicing_dict
60
+
61
+ if "c" in slicing_dict:
62
+ slicing_dict.pop("c", None)
63
+ return slicing_dict
64
+
65
+
66
+ def _check_slicing_virtual_axes(slice_: SlicingInputType) -> bool:
67
+ """Check if the slice_ is compatible with virtual axes.
68
+
69
+ Virtual axes are axes that are not present in the actual data,
70
+ such as time or channel axes in some datasets.
71
+ So the only valid slices for virtual axes are:
72
+ - None: means all data along the axis
73
+ - 0: means the first element along the axis
74
+ - slice([0, None], [1, None])
75
+ """
76
+ if slice_ is None or slice_ == 0:
77
+ return True
78
+ if isinstance(slice_, slice):
79
+ if slice_.start is None and slice_.stop is None:
80
+ return True
81
+ if slice_.start == 0 and slice_.stop is None:
82
+ return True
83
+ if slice_.start is None and slice_.stop == 0:
84
+ return True
85
+ if slice_.start == 0 and slice_.stop == 1:
86
+ return True
87
+ if isinstance(slice_, Sequence):
88
+ if len(slice_) == 1 and slice_[0] == 0:
89
+ return True
90
+ return False
91
+
92
+
93
+ def _clean_slicing_dict(
94
+ dimensions: Dimensions,
95
+ slicing_dict: Mapping[str, SlicingInputType],
96
+ remove_channel_selection: bool = False,
97
+ ) -> dict[str, SlicingInputType]:
98
+ """Clean the slicing dict.
99
+
100
+ This function will:
101
+ - Validate that the axes in the slicing_dict are present in the dimensions.
102
+ - Make sure that the slicing_dict uses the on-disk axis names.
103
+ - Check for duplicate axis names in the slicing_dict.
104
+ - Clean up channel selection if the dimensions
105
+ """
106
+ clean_slicing_dict: dict[str, SlicingInputType] = {}
107
+ for axis_name, slice_ in slicing_dict.items():
108
+ axis = dimensions.axes_handler.get_axis(axis_name)
109
+ if axis is None:
110
+ # Virtual axes should be allowed to be selected
111
+ # Common use case is still allowing channel_selection
112
+ # When the zarr has not channel axis.
113
+ if not _check_slicing_virtual_axes(slice_):
114
+ raise NgioValueError(
115
+ f"Invalid axis selection:{axis_name}={slice_}. "
116
+ f"Not found on the on-disk axes {dimensions.axes}."
117
+ )
118
+ # Virtual axes can be safely ignored
119
+ continue
120
+ if axis.name in clean_slicing_dict:
121
+ raise NgioValueError(
122
+ f"Duplicate axis {axis.name} in slice kwargs. "
123
+ "Please provide unique axis names."
124
+ )
125
+ clean_slicing_dict[axis.name] = slice_
126
+
127
+ if remove_channel_selection:
128
+ clean_slicing_dict = _remove_channel_slicing(
129
+ slicing_dict=clean_slicing_dict, dimensions=dimensions
130
+ )
131
+ return clean_slicing_dict
132
+
133
+
134
+ def _normalize_axes_order(
135
+ dimensions: Dimensions,
136
+ axes_order: Sequence[str],
137
+ ) -> list[str]:
138
+ """Convert axes order to the on-disk axes names.
139
+
140
+ In this way there is not unambiguity in the axes order.
141
+ """
142
+ new_axes_order = []
143
+ for axis_name in axes_order:
144
+ axis = dimensions.axes_handler.get_axis(axis_name)
145
+ if axis is None:
146
+ new_axes_order.append(axis_name)
147
+ else:
148
+ new_axes_order.append(axis.name)
149
+ return new_axes_order
150
+
151
+
152
+ def _normalize_slicing_tuple(
153
+ axis: Axis,
154
+ slicing_dict: dict[str, SlicingInputType],
155
+ no_axes_ops: bool,
156
+ axes_order: list[str],
157
+ ) -> tuple[SlicingType, str | None]:
158
+ """Normalize the slicing dict to tuple.
159
+
160
+ Since the slicing dict can contain different types of values
161
+ We need to normalize them to more predictable types.
162
+ The output types are:
163
+ - slice
164
+ - int
165
+ - tuple of int (for non-contiguous selection)
166
+ """
167
+ axis_name = axis.name
168
+ if axis_name not in slicing_dict:
169
+ # If no slice is provided for the axis, use a full slice
170
+ return slice(None), None
171
+
172
+ value = slicing_dict[axis_name]
173
+ if value is None:
174
+ return slice(None), None
175
+
176
+ if isinstance(value, slice):
177
+ return value, None
178
+ elif isinstance(value, int):
179
+ # If axes ops are requested, we need to preserve the dimension
180
+ # When we slice because the axes ops will be applied later
181
+ # If no axes ops are requested, we can safely keep the integer
182
+ # which will remove the dimension
183
+ if (not no_axes_ops) or (axis_name in axes_order):
184
+ # Axes ops require all dimensions to be preserved
185
+ value = slice(value, value + 1)
186
+ return value, None
187
+ return value, axis_name
188
+ elif isinstance(value, Sequence):
189
+ # If a contiguous sequence of integers is provided,
190
+ # convert it to a slice for efficiency
191
+ # Alternatively, it will be converted to a tuple of ints
192
+ return _try_to_slice(value), None
193
+
194
+ raise NgioValueError(
195
+ f"Invalid slice definition {value} of type {type(value)}. "
196
+ "Allowed types are: int, slice, sequence of int or None."
197
+ )
198
+
199
+
200
+ def _build_slicing_tuple(
201
+ *,
202
+ dimensions: Dimensions,
203
+ slicing_dict: dict[str, SlicingInputType],
204
+ axes_order: list[str] | None = None,
205
+ no_axes_ops: bool = False,
206
+ remove_channel_selection: bool = False,
207
+ ) -> tuple[tuple[SlicingType, ...] | None, list[str]]:
208
+ """Assemble slices to be used to query the array."""
209
+ if len(slicing_dict) == 0:
210
+ # Skip unnecessary computation if no slicing is requested
211
+ return None, []
212
+ _axes_order = (
213
+ _normalize_axes_order(dimensions=dimensions, axes_order=axes_order)
214
+ if axes_order is not None
215
+ else []
216
+ )
217
+ _slicing_dict = _clean_slicing_dict(
218
+ dimensions=dimensions,
219
+ slicing_dict=slicing_dict,
220
+ remove_channel_selection=remove_channel_selection,
221
+ )
222
+
223
+ slicing_tuple = []
224
+ axes_to_remove = []
225
+ for axis in dimensions.axes_handler.axes:
226
+ sl, ax_to_remove = _normalize_slicing_tuple(
227
+ axis=axis,
228
+ slicing_dict=_slicing_dict,
229
+ no_axes_ops=no_axes_ops,
230
+ axes_order=_axes_order,
231
+ )
232
+ slicing_tuple.append(sl)
233
+ if ax_to_remove is not None:
234
+ axes_to_remove.append(ax_to_remove)
235
+ slicing_tuple = tuple(slicing_tuple)
236
+ # Slicing tuple can have only one element of type tuple
237
+ # If multiple tuple are present it will lead to errors
238
+ # when querying the array
239
+ if sum(isinstance(s, tuple) for s in slicing_tuple) > 1:
240
+ raise NgioValueError(
241
+ f"Invalid slicing tuple {slicing_tuple}. Ngio does not support "
242
+ "multiple non-contiguous selections (tuples) in the slicing tuple. "
243
+ "Please use slices or single integer selections instead."
244
+ )
245
+ return slicing_tuple, axes_to_remove
246
+
247
+
248
def _build_axes_ops(
    *,
    axes_order: Sequence[str] | None,
    dimensions: Dimensions,
) -> tuple[list[str] | None, AxesOps]:
    """Build the axes operations for a requested axes order.

    When no order is requested, the returned ``AxesOps`` uses the on-disk
    axes names for both the on-disk and the in-memory axes, and the returned
    order is None. Otherwise the order is first converted to on-disk names
    and the operations are obtained from ``dimensions.axes_handler``.

    Returns:
        A pair ``(normalized_axes_order, axes_ops)``.
    """
    handler = dimensions.axes_handler
    if axes_order is not None:
        normalized_order = _normalize_axes_order(
            dimensions=dimensions, axes_order=axes_order
        )
        return normalized_order, handler.get_axes_ops(normalized_order)

    identity_ops = AxesOps(
        on_disk_axes=handler.axes_names,
        in_memory_axes=handler.axes_names,
    )
    return None, identity_ops
262
+
263
+
264
def setup_io_pipe(
    *,
    dimensions: Dimensions,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    axes_order: Sequence[str] | None = None,
    remove_channel_selection: bool = False,
) -> tuple[SlicingOps, AxesOps]:
    """Set up the slicing and axes operations for an IO pipe.

    Args:
        dimensions: Dimensions of the on-disk array.
        slicing_dict: Optional per-axis selections.
        axes_order: Optional requested in-memory axes order.
        remove_channel_selection: Forwarded to the slicing builder to clean
            up the channel selection.

    Returns:
        The ``SlicingOps`` and ``AxesOps`` describing the query.
    """
    if not slicing_dict:
        slicing_dict = {}

    axes_order, axes_ops = _build_axes_ops(
        axes_order=axes_order,
        dimensions=dimensions,
    )
    slicing_tuple, axes_to_remove = _build_slicing_tuple(
        dimensions=dimensions,
        slicing_dict=slicing_dict,
        axes_order=axes_order,
        no_axes_ops=axes_ops.is_no_op,
        remove_channel_selection=remove_channel_selection,
    )

    if axes_to_remove:
        # Integer selections dropped these dimensions, so they must also
        # disappear from the in-memory axes.
        kept_axes = tuple(
            filter(lambda ax: ax not in axes_to_remove, axes_ops.in_memory_axes)
        )
        axes_ops = AxesOps(
            on_disk_axes=axes_ops.on_disk_axes,
            in_memory_axes=kept_axes,
        )

    return (
        SlicingOps(
            on_disk_axes=dimensions.axes_handler.axes_names,
            slicing_tuple=slicing_tuple,
            on_disk_shape=dimensions.shape,
        ),
        axes_ops,
    )
@@ -0,0 +1,376 @@
1
+ from collections.abc import Sequence
2
+ from enum import Enum
3
+
4
+ import dask.array as da
5
+ import numpy as np
6
+
7
+ from ngio.utils import NgioValueError, ngio_logger
8
+
9
+
10
class Action(str, Enum):
    """Per-axis operation used to bring an array to a reference shape."""

    # Leave the axis untouched
    NONE = "none"
    # Grow the axis by padding
    PAD = "pad"
    # Shrink the axis by trimming
    TRIM = "trim"
    # Resize the axis by rescaling
    RESCALING = "rescaling"
15
+
16
+
17
def _compute_pad_widths(
    array_shape: tuple[int, ...],
    actions: list[Action],
    target_shape: tuple[int, ...],
) -> tuple[tuple[int, int], ...]:
    """Compute the ``(before, after)`` pad widths for every axis.

    Axes flagged with ``Action.PAD`` receive the missing elements split
    between both sides (the extra element, if any, goes after); every other
    axis gets ``(0, 0)``. A warning describing the padding is logged.
    """

    def _split(missing: int) -> tuple[int, int]:
        leading = missing // 2
        return leading, missing - leading

    pad_def = [
        _split(ts - s) if act == Action.PAD else (0, 0)
        for act, s, ts in zip(actions, array_shape, target_shape, strict=True)
    ]
    ngio_logger.warning(
        f"Images have a different shape ({array_shape} vs {target_shape}). "
        f"Resolving by padding: {pad_def}",
        stacklevel=2,
    )
    return tuple(pad_def)
37
+
38
+
39
def _numpy_pad(
    array: np.ndarray,
    actions: list[Action],
    target_shape: tuple[int, ...],
    pad_mode: str = "constant",
    constant_values: int | float = 0,
) -> np.ndarray:
    """Pad the axes flagged with ``Action.PAD`` up to the target shape.

    Args:
        array: The array to pad.
        actions: One action per axis of ``array``.
        target_shape: The shape to reach on the padded axes.
        pad_mode: Padding mode forwarded to ``numpy.pad``.
        constant_values: Fill value, only used when ``pad_mode`` is
            ``"constant"``.

    Returns:
        The padded array, or the input unchanged if no axis needs padding.
    """
    if all(act != Action.PAD for act in actions):
        return array
    pad_widths = _compute_pad_widths(array.shape, actions, target_shape)
    if pad_mode == "constant":
        return np.pad(array, pad_widths, mode=pad_mode, constant_values=constant_values)  # type: ignore
    # numpy.pad rejects `constant_values` for every mode other than
    # "constant", so it must not be forwarded for e.g. "edge" or "reflect".
    return np.pad(array, pad_widths, mode=pad_mode)  # type: ignore
50
+
51
+
52
def _dask_pad(
    array: da.Array,
    actions: list[Action],
    target_shape: tuple[int, ...],
    pad_mode: str = "constant",
    constant_values: int | float = 0,
) -> da.Array:
    """Pad the axes flagged with ``Action.PAD`` up to the target shape.

    Args:
        array: The dask array to pad.
        actions: One action per axis of ``array``.
        target_shape: The shape to reach on the padded axes.
        pad_mode: Padding mode forwarded to ``dask.array.pad``.
        constant_values: Fill value, only used when ``pad_mode`` is
            ``"constant"``.

    Returns:
        The padded array, or the input unchanged if no axis needs padding.
    """
    if all(act != Action.PAD for act in actions):
        return array
    shape = tuple(int(s) for s in array.shape)
    pad_widths = _compute_pad_widths(shape, actions, target_shape)
    if pad_mode == "constant":
        return da.pad(
            array, pad_widths, mode=pad_mode, constant_values=constant_values
        )
    # As with numpy.pad, `constant_values` is only a valid keyword for the
    # "constant" mode and must not be forwarded for the other modes.
    return da.pad(array, pad_widths, mode=pad_mode)
64
+
65
+
66
def _compute_trim_slices(
    array_shape: tuple[int, ...],
    actions: list[Action],
    target_shape: tuple[int, ...],
) -> tuple[slice, ...]:
    """Compute the slices that trim every axis flagged with ``Action.TRIM``.

    Trimmed axes keep their first ``target`` elements; all other axes keep
    their full current length. A warning describing the trim is logged.
    """
    slices = [
        slice(0, ts if act == Action.TRIM else s)
        for act, s, ts in zip(actions, array_shape, target_shape, strict=True)
    ]
    ngio_logger.warning(
        f"Images have a different shape ({array_shape} vs {target_shape}). "
        f"Resolving by trimming: {slices}",
        stacklevel=2,
    )
    return tuple(slices)
84
+
85
+
86
def _numpy_trim(
    array: np.ndarray, actions: list[Action], target_shape: tuple[int, ...]
) -> np.ndarray:
    """Trim the axes flagged with ``Action.TRIM`` down to the target shape.

    The input is returned unchanged when no axis needs trimming.
    """
    if Action.TRIM not in actions:
        return array
    trim_slices = _compute_trim_slices(array.shape, actions, target_shape)
    return array[trim_slices]
93
+
94
+
95
def _dask_trim(
    array: da.Array, actions: list[Action], target_shape: tuple[int, ...]
) -> da.Array:
    """Trim the axes flagged with ``Action.TRIM`` down to the target shape.

    The input is returned unchanged when no axis needs trimming.
    """
    if Action.TRIM not in actions:
        return array
    # Shape entries are cast to plain ints before computing the slices
    current_shape = tuple(int(s) for s in array.shape)
    trim_slices = _compute_trim_slices(current_shape, actions, target_shape)
    return array[trim_slices]
103
+
104
+
105
def _compute_rescaling_shape(
    array_shape: tuple[int, ...],
    actions: list[Action],
    target_shape: tuple[int, ...],
) -> tuple[int, ...]:
    """Compute the shape the array must be rescaled to.

    Axes flagged with ``Action.RESCALING`` take the target length, all the
    others keep their current length. A warning reporting the per-axis
    scaling factors is logged.
    """
    rescaling_shape = []
    factor = []
    for act, s, ts in zip(actions, array_shape, target_shape, strict=True):
        needs_rescaling = act == Action.RESCALING
        rescaling_shape.append(ts if needs_rescaling else s)
        factor.append(ts / s if needs_rescaling else 1.0)

    ngio_logger.warning(
        f"Images have a different shape ({array_shape} vs {target_shape}). "
        f"Resolving by scaling with factors {factor}.",
        stacklevel=2,
    )
    return tuple(rescaling_shape)
126
+
127
+
128
def _numpy_rescaling(
    array: np.ndarray, actions: list[Action], target_shape: tuple[int, ...]
) -> np.ndarray:
    """Rescale the axes flagged with ``Action.RESCALING`` to the target shape.

    The zoom is done with ``order="nearest"``. The input is returned
    unchanged when no axis needs rescaling.
    """
    if Action.RESCALING not in actions:
        return array

    # Imported here, only once rescaling is known to be needed
    from ngio.common._zoom import numpy_zoom

    zoomed_shape = _compute_rescaling_shape(array.shape, actions, target_shape)
    return numpy_zoom(source_array=array, target_shape=zoomed_shape, order="nearest")
137
+
138
+
139
def _dask_rescaling(
    array: da.Array, actions: list[Action], target_shape: tuple[int, ...]
) -> da.Array:
    """Rescale the axes flagged with ``Action.RESCALING`` to the target shape.

    The zoom is done with ``order="nearest"``. The input is returned
    unchanged when no axis needs rescaling.
    """
    if Action.RESCALING not in actions:
        return array

    # Imported here, only once rescaling is known to be needed
    from ngio.common._zoom import dask_zoom

    current_shape = tuple(int(s) for s in array.shape)
    zoomed_shape = _compute_rescaling_shape(current_shape, actions, target_shape)
    return dask_zoom(source_array=array, target_shape=zoomed_shape, order="nearest")
149
+
150
+
151
+ def _check_axes(array_shape, reference_shape, array_axes, reference_axes):
152
+ if len(array_shape) != len(array_axes):
153
+ raise NgioValueError(
154
+ f"Array shape {array_shape} and reference axes {array_axes} "
155
+ "must have the same number of dimensions."
156
+ )
157
+ if len(reference_shape) != len(reference_axes):
158
+ raise NgioValueError(
159
+ f"Reference shape {reference_shape} and reference axes {reference_axes} "
160
+ "must have the same number of dimensions."
161
+ )
162
+
163
+ # Check if the array axes are a subset of the target axes
164
+ diff = set(array_axes) - set(reference_axes)
165
+ if diff:
166
+ raise NgioValueError(
167
+ f"Array axes {array_axes} are not a subset "
168
+ f"of reference axes {reference_axes}"
169
+ )
170
+
171
+ # Array must be smaller or equal in number of dimensions
172
+ if len(array_axes) > len(reference_axes):
173
+ raise NgioValueError(
174
+ f"Array has more dimensions ({len(array_axes)}) "
175
+ f"than reference ({len(reference_axes)}). "
176
+ "Cannot match shapes if the array has more dimensions."
177
+ )
178
+
179
+
180
def _compute_reshape_and_actions(
    array_shape: tuple[int, ...],
    reference_shape: tuple[int, ...],
    array_axes: list[str],
    reference_axes: list[str],
    tolerance: int = 1,
    allow_rescaling: bool = True,
) -> tuple[tuple[int, ...], list[Action]]:
    """Compute the reshape tuple and the per-axis actions to match shapes.

    The reference axes are walked in order. A reference axis missing from
    the array becomes a new axis of length 1. A matching axis keeps its
    length and is assigned an action:
        - NONE if the lengths are equal or the array length is 1,
        - PAD / TRIM if the difference is within ``tolerance``,
        - RESCALING if the difference is larger and rescaling is allowed.

    Args:
        array_shape: Shape of the array.
        reference_shape: Shape to match.
        array_axes: Axes names of the array.
        reference_axes: Axes names of the reference.
        tolerance: Maximum difference that can be resolved by pad or trim.
        allow_rescaling: Whether larger differences may be rescaled.

    Returns:
        The reshape tuple and the list of actions, one entry per
        reference axis.

    Raises:
        NgioValueError: If the axes are in a different order, or if a
            difference is outside the tolerance and rescaling is disabled.
    """
    reshape_tuple = []
    actions = []
    errors = []
    left_pointer = 0
    for ref_ax, ref_shape in zip(reference_axes, reference_shape, strict=True):
        if ref_ax not in array_axes:
            reshape_tuple.append(1)
            actions.append(Action.NONE)
            continue

        # Guard the index so an exhausted array reports a clear order
        # mismatch instead of an IndexError.
        if left_pointer >= len(array_axes) or ref_ax != array_axes[left_pointer]:
            raise NgioValueError(
                f"Axes order mismatch {array_axes} -> {reference_axes}. "
                "Cannot match shapes if the order is different."
            )

        s2 = array_shape[left_pointer]
        reshape_tuple.append(s2)
        left_pointer += 1

        if s2 == ref_shape or s2 == 1:
            actions.append(Action.NONE)
        elif s2 < ref_shape:
            if (ref_shape - s2) <= tolerance:
                actions.append(Action.PAD)
            elif allow_rescaling:
                actions.append(Action.RESCALING)
            else:
                errors.append(
                    f"Cannot pad axis={ref_ax}:{s2}->{ref_shape} "
                    "because shape difference is outside tolerance "
                    f"{tolerance}."
                )
        elif s2 > ref_shape:
            if (s2 - ref_shape) <= tolerance:
                actions.append(Action.TRIM)
            elif allow_rescaling:
                actions.append(Action.RESCALING)
            else:
                errors.append(
                    f"Cannot trim axis={ref_ax}:{s2}->{ref_shape} "
                    "because shape difference is outside tolerance "
                    f"{tolerance}."
                )
        else:
            raise RuntimeError("Unreachable code reached.")

    if errors:
        # The header goes first and the errors are listed below it; it must
        # not be used as the separator passed to str.join.
        raise NgioValueError(
            "Array shape cannot be matched to reference shape:\n\n"
            + "\n".join(errors)
        )
    return tuple(reshape_tuple), actions
240
+
241
+
242
def numpy_match_shape(
    array: np.ndarray,
    reference_shape: tuple[int, ...],
    array_axes: Sequence[str],
    reference_axes: Sequence[str],
    tolerance: int = 1,
    pad_mode: str = "constant",
    pad_values: int | float = 0,
    allow_rescaling: bool = True,
) -> np.ndarray:
    """Match the shape of a numpy array to a reference shape.

    This function will reshape, rescale, pad and trim the input array
    to match the reference shape. Axes missing from the array are added
    with length 1, and axes of length 1 are left as is. If the shapes
    cannot be matched within the specified tolerance and rescaling is not
    allowed, an error is raised.

    The reference axes must be a superset of the array axes, and the order
    of the axes must be the same.

    Args:
        array (np.ndarray): The input array to be reshaped.
        reference_shape (tuple[int, ...]): The target shape to match.
        array_axes (Sequence[str]): The axes names of the input array.
        reference_axes (Sequence[str]): The axes names of the reference shape.
        tolerance (int): The maximum number of pixels by which dimensions
            can differ when matching shapes.
        pad_mode (str): The mode to use for padding. See numpy.pad for options.
        pad_values (int | float): The constant value to use for padding if
            pad_mode is 'constant'.
        allow_rescaling (bool): If True, when the array differs more than the
            tolerance, it will be rescaled to the reference shape. If False,
            an error will be raised.

    Returns:
        np.ndarray: The array matched to the reference shape.
    """
    _check_axes(
        array_shape=array.shape,
        reference_shape=reference_shape,
        array_axes=array_axes,
        reference_axes=reference_axes,
    )
    if array.shape == reference_shape:
        # Shapes already match
        return array

    array_axes = list(array_axes)
    reference_axes = list(reference_axes)

    reshape_tuple, actions = _compute_reshape_and_actions(
        array_shape=array.shape,
        reference_shape=reference_shape,
        array_axes=array_axes,
        reference_axes=reference_axes,
        tolerance=tolerance,
        allow_rescaling=allow_rescaling,
    )
    # Order matters: first add the missing axes, then resolve large
    # differences by rescaling and small ones by padding / trimming.
    array = array.reshape(reshape_tuple)
    array = _numpy_rescaling(array=array, actions=actions, target_shape=reference_shape)
    array = _numpy_pad(
        array=array,
        actions=actions,
        target_shape=reference_shape,
        pad_mode=pad_mode,
        constant_values=pad_values,
    )
    array = _numpy_trim(array=array, actions=actions, target_shape=reference_shape)
    return array
310
+
311
+
312
def dask_match_shape(
    array: da.Array,
    reference_shape: tuple[int, ...],
    array_axes: Sequence[str],
    reference_axes: Sequence[str],
    tolerance: int = 1,
    pad_mode: str = "constant",
    pad_values: int | float = 0,
    allow_rescaling: bool = True,
) -> da.Array:
    """Match the shape of a dask array to a reference shape.

    The array is reshaped, rescaled, padded and trimmed so that it matches
    the reference shape. If the shapes cannot be matched within the
    specified tolerance and rescaling is not allowed, an error is raised.

    The reference axes must be a superset of the array axes, and the order
    of the axes must be the same.

    Args:
        array (da.Array): The input array to be reshaped.
        reference_shape (tuple[int, ...]): The target shape to match.
        array_axes (Sequence[str]): The axes names of the input array.
        reference_axes (Sequence[str]): The axes names of the reference shape.
        tolerance (int): The maximum number of pixels by which dimensions
            can differ when matching shapes.
        pad_mode (str): The mode to use for padding. See numpy.pad for options.
        pad_values (int | float): The constant value to use for padding if
            pad_mode is 'constant'.
        allow_rescaling (bool): If True, when the array differs more than the
            tolerance, it will be rescaled to the reference shape. If False,
            an error will be raised.
    """
    # Shape entries are cast to plain ints before any comparison
    array_shape = tuple(int(s) for s in array.shape)
    _check_axes(
        array_shape=array_shape,
        reference_shape=reference_shape,
        array_axes=array_axes,
        reference_axes=reference_axes,
    )
    if array_shape == reference_shape:
        # Nothing to do, the shapes already agree
        return array

    reshape_tuple, actions = _compute_reshape_and_actions(
        array_shape=array_shape,
        reference_shape=reference_shape,
        array_axes=list(array_axes),
        reference_axes=list(reference_axes),
        tolerance=tolerance,
        allow_rescaling=allow_rescaling,
    )
    reshaped = da.reshape(array, reshape_tuple)
    rescaled = _dask_rescaling(
        array=reshaped, actions=actions, target_shape=reference_shape
    )
    padded = _dask_pad(
        array=rescaled,
        actions=actions,
        target_shape=reference_shape,
        pad_mode=pad_mode,
        constant_values=pad_values,
    )
    return _dask_trim(array=padded, actions=actions, target_shape=reference_shape)