ngio 0.4.0a4__py3-none-any.whl → 0.4.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ngio/common/_dimensions.py +209 -189
- ngio/common/_roi.py +2 -2
- ngio/experimental/iterators/__init__.py +0 -2
- ngio/experimental/iterators/_abstract_iterator.py +246 -26
- ngio/experimental/iterators/_feature.py +84 -41
- ngio/experimental/iterators/_image_processing.py +7 -36
- ngio/experimental/iterators/_mappers.py +48 -0
- ngio/experimental/iterators/_segmentation.py +7 -38
- ngio/images/_abstract_image.py +60 -5
- ngio/images/_image.py +2 -0
- ngio/images/_label.py +2 -0
- ngio/images/_masked_image.py +22 -17
- ngio/io_pipes/__init__.py +29 -3
- ngio/io_pipes/_io_pipes.py +93 -18
- ngio/io_pipes/_io_pipes_masked.py +17 -10
- ngio/io_pipes/_io_pipes_roi.py +10 -1
- ngio/io_pipes/_io_pipes_types.py +56 -0
- ngio/io_pipes/_ops_axes.py +199 -1
- ngio/io_pipes/_ops_slices.py +255 -27
- ngio/io_pipes/_ops_slices_utils.py +196 -0
- ngio/io_pipes/_ops_transforms.py +1 -1
- ngio/io_pipes/_zoom_transform.py +5 -5
- ngio/ome_zarr_meta/__init__.py +0 -2
- ngio/ome_zarr_meta/ngio_specs/__init__.py +0 -2
- ngio/ome_zarr_meta/ngio_specs/_axes.py +7 -131
- ngio/utils/_datasets.py +5 -0
- {ngio-0.4.0a4.dist-info → ngio-0.4.0b1.dist-info}/METADATA +1 -1
- {ngio-0.4.0a4.dist-info → ngio-0.4.0b1.dist-info}/RECORD +30 -28
- ngio/io_pipes/_io_pipes_utils.py +0 -299
- {ngio-0.4.0a4.dist-info → ngio-0.4.0b1.dist-info}/WHEEL +0 -0
- {ngio-0.4.0a4.dist-info → ngio-0.4.0b1.dist-info}/licenses/LICENSE +0 -0
|
@@ -16,8 +16,8 @@ from ngio.io_pipes._io_pipes import (
|
|
|
16
16
|
NumpySetter,
|
|
17
17
|
)
|
|
18
18
|
from ngio.io_pipes._io_pipes_roi import roi_to_slicing_dict
|
|
19
|
-
from ngio.io_pipes._io_pipes_utils import SlicingInputType
|
|
20
19
|
from ngio.io_pipes._match_shape import dask_match_shape, numpy_match_shape
|
|
20
|
+
from ngio.io_pipes._ops_slices import SlicingInputType
|
|
21
21
|
from ngio.io_pipes._ops_transforms import TransformProtocol
|
|
22
22
|
from ngio.io_pipes._zoom_transform import BaseZoomTransform
|
|
23
23
|
|
|
@@ -112,6 +112,7 @@ def _setup_numpy_getters(
|
|
|
112
112
|
class NumpyGetterMasked(DataGetter[np.ndarray]):
|
|
113
113
|
def __init__(
|
|
114
114
|
self,
|
|
115
|
+
*,
|
|
115
116
|
zarr_array: zarr.Array,
|
|
116
117
|
dimensions: Dimensions,
|
|
117
118
|
label_zarr_array: zarr.Array,
|
|
@@ -152,6 +153,7 @@ class NumpyGetterMasked(DataGetter[np.ndarray]):
|
|
|
152
153
|
slicing_ops=self._data_getter.slicing_ops,
|
|
153
154
|
axes_ops=self._data_getter.axes_ops,
|
|
154
155
|
transforms=self._data_getter.transforms,
|
|
156
|
+
roi=roi,
|
|
155
157
|
)
|
|
156
158
|
|
|
157
159
|
@property
|
|
@@ -167,8 +169,8 @@ class NumpyGetterMasked(DataGetter[np.ndarray]):
|
|
|
167
169
|
label_data=label_data,
|
|
168
170
|
label=self.label_id,
|
|
169
171
|
data_shape=data.shape,
|
|
170
|
-
label_axes=self._label_data_getter.axes_ops.
|
|
171
|
-
data_axes=self._data_getter.axes_ops.
|
|
172
|
+
label_axes=self._label_data_getter.axes_ops.output_axes,
|
|
173
|
+
data_axes=self._data_getter.axes_ops.output_axes,
|
|
172
174
|
allow_rescaling=self._allow_rescaling,
|
|
173
175
|
)
|
|
174
176
|
if bool_mask.shape != data.shape:
|
|
@@ -180,6 +182,7 @@ class NumpyGetterMasked(DataGetter[np.ndarray]):
|
|
|
180
182
|
class NumpySetterMasked(DataSetter[np.ndarray]):
|
|
181
183
|
def __init__(
|
|
182
184
|
self,
|
|
185
|
+
*,
|
|
183
186
|
zarr_array: zarr.Array,
|
|
184
187
|
dimensions: Dimensions,
|
|
185
188
|
label_zarr_array: zarr.Array,
|
|
@@ -226,6 +229,7 @@ class NumpySetterMasked(DataSetter[np.ndarray]):
|
|
|
226
229
|
slicing_ops=self._data_setter.slicing_ops,
|
|
227
230
|
axes_ops=self._data_setter.axes_ops,
|
|
228
231
|
transforms=self._data_setter.transforms,
|
|
232
|
+
roi=roi,
|
|
229
233
|
)
|
|
230
234
|
|
|
231
235
|
@property
|
|
@@ -235,14 +239,13 @@ class NumpySetterMasked(DataSetter[np.ndarray]):
|
|
|
235
239
|
def set(self, patch: np.ndarray) -> None:
|
|
236
240
|
data = self._data_getter()
|
|
237
241
|
label_data = self._label_data_getter()
|
|
238
|
-
print(data.shape, label_data.shape)
|
|
239
242
|
|
|
240
243
|
bool_mask = _numpy_label_to_bool_mask(
|
|
241
244
|
label_data=label_data,
|
|
242
245
|
label=self.label_id,
|
|
243
246
|
data_shape=data.shape,
|
|
244
|
-
label_axes=self._label_data_getter.axes_ops.
|
|
245
|
-
data_axes=self._data_getter.axes_ops.
|
|
247
|
+
label_axes=self._label_data_getter.axes_ops.output_axes,
|
|
248
|
+
data_axes=self._data_getter.axes_ops.output_axes,
|
|
246
249
|
allow_rescaling=self._allow_rescaling,
|
|
247
250
|
)
|
|
248
251
|
if bool_mask.shape != data.shape:
|
|
@@ -342,6 +345,7 @@ def _setup_dask_getters(
|
|
|
342
345
|
class DaskGetterMasked(DataGetter[DaskArray]):
|
|
343
346
|
def __init__(
|
|
344
347
|
self,
|
|
348
|
+
*,
|
|
345
349
|
zarr_array: zarr.Array,
|
|
346
350
|
dimensions: Dimensions,
|
|
347
351
|
label_zarr_array: zarr.Array,
|
|
@@ -381,6 +385,7 @@ class DaskGetterMasked(DataGetter[DaskArray]):
|
|
|
381
385
|
slicing_ops=self._data_getter.slicing_ops,
|
|
382
386
|
axes_ops=self._data_getter.axes_ops,
|
|
383
387
|
transforms=self._data_getter.transforms,
|
|
388
|
+
roi=roi,
|
|
384
389
|
)
|
|
385
390
|
|
|
386
391
|
@property
|
|
@@ -395,8 +400,8 @@ class DaskGetterMasked(DataGetter[DaskArray]):
|
|
|
395
400
|
label_data=label_data,
|
|
396
401
|
label=self.label_id,
|
|
397
402
|
data_shape=data_shape,
|
|
398
|
-
label_axes=self._label_data_getter.axes_ops.
|
|
399
|
-
data_axes=self._data_getter.axes_ops.
|
|
403
|
+
label_axes=self._label_data_getter.axes_ops.output_axes,
|
|
404
|
+
data_axes=self._data_getter.axes_ops.output_axes,
|
|
400
405
|
allow_rescaling=self._allow_rescaling,
|
|
401
406
|
)
|
|
402
407
|
if bool_mask.shape != data.shape:
|
|
@@ -408,6 +413,7 @@ class DaskGetterMasked(DataGetter[DaskArray]):
|
|
|
408
413
|
class DaskSetterMasked(DataSetter[DaskArray]):
|
|
409
414
|
def __init__(
|
|
410
415
|
self,
|
|
416
|
+
*,
|
|
411
417
|
zarr_array: zarr.Array,
|
|
412
418
|
dimensions: Dimensions,
|
|
413
419
|
label_zarr_array: zarr.Array,
|
|
@@ -456,6 +462,7 @@ class DaskSetterMasked(DataSetter[DaskArray]):
|
|
|
456
462
|
slicing_ops=self._data_setter.slicing_ops,
|
|
457
463
|
axes_ops=self._data_setter.axes_ops,
|
|
458
464
|
transforms=self._data_setter.transforms,
|
|
465
|
+
roi=roi,
|
|
459
466
|
)
|
|
460
467
|
|
|
461
468
|
@property
|
|
@@ -471,8 +478,8 @@ class DaskSetterMasked(DataSetter[DaskArray]):
|
|
|
471
478
|
label_data=label_data,
|
|
472
479
|
label=self.label_id,
|
|
473
480
|
data_shape=data_shape,
|
|
474
|
-
label_axes=self._label_data_getter.axes_ops.
|
|
475
|
-
data_axes=self._data_getter.axes_ops.
|
|
481
|
+
label_axes=self._label_data_getter.axes_ops.output_axes,
|
|
482
|
+
data_axes=self._data_getter.axes_ops.output_axes,
|
|
476
483
|
allow_rescaling=self._allow_rescaling,
|
|
477
484
|
)
|
|
478
485
|
if bool_mask.shape != data.shape:
|
ngio/io_pipes/_io_pipes_roi.py
CHANGED
|
@@ -10,13 +10,14 @@ from ngio.io_pipes._io_pipes import (
|
|
|
10
10
|
NumpyGetter,
|
|
11
11
|
NumpySetter,
|
|
12
12
|
)
|
|
13
|
-
from ngio.io_pipes.
|
|
13
|
+
from ngio.io_pipes._ops_slices import SlicingInputType
|
|
14
14
|
from ngio.io_pipes._ops_transforms import TransformProtocol
|
|
15
15
|
from ngio.ome_zarr_meta.ngio_specs._pixel_size import PixelSize
|
|
16
16
|
from ngio.utils import NgioValueError
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def roi_to_slicing_dict(
|
|
20
|
+
*,
|
|
20
21
|
roi: Roi | RoiPixels,
|
|
21
22
|
pixel_size: PixelSize | None = None,
|
|
22
23
|
slicing_dict: dict[str, SlicingInputType] | None = None,
|
|
@@ -42,6 +43,7 @@ def roi_to_slicing_dict(
|
|
|
42
43
|
class NumpyRoiGetter(NumpyGetter):
|
|
43
44
|
def __init__(
|
|
44
45
|
self,
|
|
46
|
+
*,
|
|
45
47
|
zarr_array: zarr.Array,
|
|
46
48
|
dimensions: Dimensions,
|
|
47
49
|
roi: Roi | RoiPixels,
|
|
@@ -62,12 +64,14 @@ class NumpyRoiGetter(NumpyGetter):
|
|
|
62
64
|
transforms=transforms,
|
|
63
65
|
slicing_dict=input_slice_kwargs,
|
|
64
66
|
remove_channel_selection=remove_channel_selection,
|
|
67
|
+
roi=roi,
|
|
65
68
|
)
|
|
66
69
|
|
|
67
70
|
|
|
68
71
|
class DaskRoiGetter(DaskGetter):
|
|
69
72
|
def __init__(
|
|
70
73
|
self,
|
|
74
|
+
*,
|
|
71
75
|
zarr_array: zarr.Array,
|
|
72
76
|
dimensions: Dimensions,
|
|
73
77
|
roi: Roi | RoiPixels,
|
|
@@ -88,12 +92,14 @@ class DaskRoiGetter(DaskGetter):
|
|
|
88
92
|
transforms=transforms,
|
|
89
93
|
slicing_dict=input_slice_kwargs,
|
|
90
94
|
remove_channel_selection=remove_channel_selection,
|
|
95
|
+
roi=roi,
|
|
91
96
|
)
|
|
92
97
|
|
|
93
98
|
|
|
94
99
|
class NumpyRoiSetter(NumpySetter):
|
|
95
100
|
def __init__(
|
|
96
101
|
self,
|
|
102
|
+
*,
|
|
97
103
|
zarr_array: zarr.Array,
|
|
98
104
|
dimensions: Dimensions,
|
|
99
105
|
roi: Roi | RoiPixels,
|
|
@@ -114,12 +120,14 @@ class NumpyRoiSetter(NumpySetter):
|
|
|
114
120
|
transforms=transforms,
|
|
115
121
|
slicing_dict=input_slice_kwargs,
|
|
116
122
|
remove_channel_selection=remove_channel_selection,
|
|
123
|
+
roi=roi,
|
|
117
124
|
)
|
|
118
125
|
|
|
119
126
|
|
|
120
127
|
class DaskRoiSetter(DaskSetter):
|
|
121
128
|
def __init__(
|
|
122
129
|
self,
|
|
130
|
+
*,
|
|
123
131
|
zarr_array: zarr.Array,
|
|
124
132
|
dimensions: Dimensions,
|
|
125
133
|
roi: Roi | RoiPixels,
|
|
@@ -140,4 +148,5 @@ class DaskRoiSetter(DaskSetter):
|
|
|
140
148
|
transforms=transforms,
|
|
141
149
|
slicing_dict=input_slice_kwargs,
|
|
142
150
|
remove_channel_selection=remove_channel_selection,
|
|
151
|
+
roi=roi,
|
|
143
152
|
)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Protocol, TypeVar
|
|
3
|
+
|
|
4
|
+
import zarr
|
|
5
|
+
|
|
6
|
+
from ngio.common._roi import Roi, RoiPixels
|
|
7
|
+
from ngio.io_pipes._ops_axes import AxesOps
|
|
8
|
+
from ngio.io_pipes._ops_slices import SlicingOps
|
|
9
|
+
from ngio.io_pipes._ops_transforms import TransformProtocol
|
|
10
|
+
|
|
11
|
+
# Getters return data (``get``/``__call__``), so their type parameter is covariant.
GetterDataType = TypeVar("GetterDataType", covariant=True)
# Setters accept data (``set``/``__call__``), so their type parameter is contravariant.
SetterDataType = TypeVar("SetterDataType", contravariant=True)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DataGetterProtocol(Protocol[GetterDataType]):
    """Structural interface for data getters.

    A getter exposes its zarr array together with the slicing, axes and
    transform operations (and the ROI) it was configured with, and returns
    data from ``get()``. Calling the object is shorthand for ``get()``.
    """

    @property
    def zarr_array(self) -> zarr.Array:
        """The zarr array associated with this getter."""

    @property
    def slicing_ops(self) -> SlicingOps:
        """The slicing operations of this getter."""

    @property
    def axes_ops(self) -> AxesOps:
        """The axes operations of this getter."""

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None:
        """The transforms of this getter, if any."""

    @property
    def roi(self) -> Roi | RoiPixels:
        """The ROI associated with this getter."""

    def get(self) -> GetterDataType:
        """Return the data."""

    def __call__(self) -> GetterDataType:
        """Alias for ``get()``."""
        return self.get()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class DataSetterProtocol(Protocol[SetterDataType]):
    """Structural interface for data setters.

    A setter exposes its zarr array together with the slicing, axes and
    transform operations (and the ROI) it was configured with, and accepts a
    patch in ``set(patch)``. Calling the object with a patch is shorthand for
    ``set(patch)``.
    """

    @property
    def zarr_array(self) -> zarr.Array:
        """The zarr array associated with this setter."""

    @property
    def slicing_ops(self) -> SlicingOps:
        """The slicing operations of this setter."""

    @property
    def axes_ops(self) -> AxesOps:
        """The axes operations of this setter."""

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None:
        """The transforms of this setter, if any."""

    @property
    def roi(self) -> Roi | RoiPixels:
        """The ROI associated with this setter."""

    def set(self, patch: SetterDataType) -> None:
        """Store ``patch``."""

    def __call__(self, patch: SetterDataType) -> None:
        """Alias for ``set(patch)``."""
        return self.set(patch)
|
ngio/io_pipes/_ops_axes.py
CHANGED
|
@@ -3,8 +3,80 @@ from typing import TypeVar
|
|
|
3
3
|
|
|
4
4
|
import dask.array as da
|
|
5
5
|
import numpy as np
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
7
|
|
|
7
|
-
from ngio.
|
|
8
|
+
from ngio.common._dimensions import Dimensions
|
|
9
|
+
from ngio.utils import NgioValueError
|
|
10
|
+
|
|
11
|
+
##############################################################
|
|
12
|
+
#
|
|
13
|
+
# "AxesOps" Model
|
|
14
|
+
#
|
|
15
|
+
##############################################################
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AxesOps(BaseModel):
    """Model to represent axes operations.

    This model is used to transform objects from the on-disk axes
    (``input_axes``) to the in-memory axes (``output_axes``).

    The ``get_*`` properties return the stored operations as-is. The
    ``set_*`` properties return their inverses for the opposite direction:
    - the transpose is inverted with ``argsort``;
    - squeeze and expand swap roles.
    A ``None`` operation means "nothing to do".
    """

    input_axes: tuple[str, ...]  # on-disk axes names
    output_axes: tuple[str, ...]  # in-memory axes names
    transpose_op: tuple[int, ...] | None = None  # axes permutation, or None
    expand_op: tuple[int, ...] | None = None  # positions of axes to insert, or None
    squeeze_op: tuple[int, ...] | None = None  # positions of axes to drop, or None
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    @property
    def is_no_op(self) -> bool:
        """Return True if no transpose, expand or squeeze is needed."""
        return (
            self.transpose_op is None
            and self.expand_op is None
            and self.squeeze_op is None
        )

    @property
    def get_transpose_op(self) -> tuple[int, ...] | None:
        """Get the transpose axes."""
        return self.transpose_op

    @property
    def get_expand_op(self) -> tuple[int, ...] | None:
        """Get the expand axes."""
        return self.expand_op

    @property
    def get_squeeze_op(self) -> tuple[int, ...] | None:
        """Get the squeeze axes."""
        return self.squeeze_op

    @property
    def set_transpose_op(self) -> tuple[int, ...] | None:
        """Return the inverse permutation of ``transpose_op``, or None.

        The values are converted to plain Python ``int`` so that the result
        matches the annotation instead of containing NumPy integer scalars.
        """
        if self.transpose_op is None:
            return None
        return tuple(int(i) for i in np.argsort(self.transpose_op))

    @property
    def set_expand_op(self) -> tuple[int, ...] | None:
        """Set the expand axes.

        Re-inserts the axes that were squeezed when reading.
        """
        return self.squeeze_op

    @property
    def set_squeeze_op(self) -> tuple[int, ...] | None:
        """Set the squeeze axes.

        Removes the axes that were expanded when reading.
        """
        return self.expand_op
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
##############################################################
|
|
76
|
+
#
|
|
77
|
+
# Axes Operations implementations
|
|
78
|
+
#
|
|
79
|
+
##############################################################
|
|
8
80
|
|
|
9
81
|
|
|
10
82
|
def _apply_numpy_axes_ops(
|
|
@@ -144,3 +216,129 @@ def set_as_sequence_axes_ops(
|
|
|
144
216
|
transpose_axes=axes_ops.set_transpose_op,
|
|
145
217
|
expand_axes=axes_ops.set_expand_op,
|
|
146
218
|
)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
##############################################################
|
|
222
|
+
#
|
|
223
|
+
# Builder functions
|
|
224
|
+
#
|
|
225
|
+
##############################################################
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def _check_output_axes(axes: Sequence[str]) -> None:
|
|
229
|
+
"""Check that the input axes are valid."""
|
|
230
|
+
unique_names = set(axes)
|
|
231
|
+
if len(unique_names) != len(axes):
|
|
232
|
+
raise NgioValueError(
|
|
233
|
+
"Duplicate axis names found. Please provide unique names for each axis."
|
|
234
|
+
)
|
|
235
|
+
for name in axes:
|
|
236
|
+
if not isinstance(name, str):
|
|
237
|
+
raise NgioValueError(
|
|
238
|
+
f"Invalid axis name '{name}'. Axis names must be strings."
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _build_squeeze_tuple(
|
|
243
|
+
input_axes: tuple[str, ...], output_axes: tuple[str, ...]
|
|
244
|
+
) -> tuple[tuple[int, ...], tuple[str, ...]]:
|
|
245
|
+
"""Build a tuple of axes to squeeze."""
|
|
246
|
+
axes_to_squeeze = []
|
|
247
|
+
axes_after_squeeze = []
|
|
248
|
+
for i, ax in enumerate(input_axes):
|
|
249
|
+
if ax not in output_axes:
|
|
250
|
+
axes_to_squeeze.append(i)
|
|
251
|
+
else:
|
|
252
|
+
axes_after_squeeze.append(ax)
|
|
253
|
+
return tuple(axes_to_squeeze), tuple(axes_after_squeeze)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _build_transpose_tuple(
|
|
257
|
+
input_axes: tuple[str, ...], output_axes: tuple[str, ...]
|
|
258
|
+
) -> tuple[tuple[int, ...], tuple[str, ...]]:
|
|
259
|
+
"""Build a tuple of axes to transpose."""
|
|
260
|
+
transposition_order = []
|
|
261
|
+
axes_names_after_transpose = []
|
|
262
|
+
for ax in output_axes:
|
|
263
|
+
if ax in input_axes:
|
|
264
|
+
transposition_order.append(input_axes.index(ax))
|
|
265
|
+
axes_names_after_transpose.append(ax)
|
|
266
|
+
return tuple(transposition_order), tuple(axes_names_after_transpose)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _build_expand_tuple(
|
|
270
|
+
input_axes: tuple[str, ...], output_axes: tuple[str, ...]
|
|
271
|
+
) -> tuple[int, ...]:
|
|
272
|
+
"""Build a tuple of axes to expand."""
|
|
273
|
+
axes_to_expand = []
|
|
274
|
+
for i, ax in enumerate(output_axes):
|
|
275
|
+
if ax not in input_axes:
|
|
276
|
+
axes_to_expand.append(i)
|
|
277
|
+
return tuple(axes_to_expand)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _build_axes_ops(
    input_axes: tuple[str, ...], output_axes: tuple[str, ...]
) -> AxesOps:
    """Build the axes operations that map ``input_axes`` onto ``output_axes``.

    The operations are derived in three steps:
    1. squeeze: input axes not present in ``output_axes``;
    2. transpose: reorder the remaining axes into the requested order;
    3. expand: requested axes that do not exist in the input.

    Args:
        input_axes: The on-disk axes names.
        output_axes: The requested in-memory axes names.

    Returns:
        An ``AxesOps`` whose ``input_axes`` are the original ``input_axes``
        (not the intermediate axes after squeezing/transposing). Empty or
        identity operations are stored as ``None``.

    Raises:
        NgioValueError: If ``output_axes`` contains non-string or repeated names.
    """
    _check_output_axes(output_axes)
    # Step 1: find the axes to squeeze. Use new names so that ``input_axes``
    # keeps the on-disk axes for the returned model.
    axes_to_squeeze, squeezed_axes = _build_squeeze_tuple(input_axes, output_axes)
    # Step 2: find the transposition order of the surviving axes.
    transposition_order, transposed_axes = _build_transpose_tuple(
        squeezed_axes, output_axes
    )
    # Step 3: find the axes to expand.
    axes_to_expand = _build_expand_tuple(transposed_axes, output_axes)

    squeeze_op = axes_to_squeeze if axes_to_squeeze else None
    # An identity permutation means no transpose is needed (exact int compare).
    identity = tuple(range(len(transposition_order)))
    transpose_op = None if transposition_order == identity else transposition_order
    expand_op = axes_to_expand if axes_to_expand else None

    return AxesOps(
        input_axes=input_axes,
        output_axes=output_axes,
        transpose_op=transpose_op,
        expand_op=expand_op,
        squeeze_op=squeeze_op,
    )
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _normalize_axes_order(
    dimensions: Dimensions,
    axes_order: Sequence[str],
) -> tuple[str, ...]:
    """Convert the requested axes order to the on-disk axes names.

    Each name that ``dimensions.axes_handler.get_axis`` resolves is replaced
    by the resolved axis' ``name``. Unresolved names are kept unchanged. This
    way the axes order is unambiguous.
    """

    def _on_disk_name(axis_name: str) -> str:
        axis = dimensions.axes_handler.get_axis(axis_name)
        return axis_name if axis is None else axis.name

    return tuple(_on_disk_name(axis_name) for axis_name in axes_order)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def build_axes_ops(
    *,
    dimensions: Dimensions,
    input_axes: tuple[str, ...],
    axes_order: Sequence[str] | None,
) -> AxesOps:
    """Build the axes operations to read ``input_axes`` in ``axes_order``.

    Args:
        dimensions: Used to resolve the names in ``axes_order`` to on-disk
            axes names.
        input_axes: The on-disk axes names.
        axes_order: The requested in-memory axes order. If None, no axes
            operation is applied.

    Returns:
        The corresponding ``AxesOps``.
    """
    if axes_order is not None:
        normalized_order = _normalize_axes_order(
            dimensions=dimensions, axes_order=axes_order
        )
        return _build_axes_ops(input_axes=input_axes, output_axes=normalized_order)
    # No reordering requested: output axes are the input axes, no operations.
    return AxesOps(
        input_axes=input_axes,
        output_axes=input_axes,
    )
|