ngio 0.4.0a4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,8 +16,8 @@ from ngio.io_pipes._io_pipes import (
16
16
  NumpySetter,
17
17
  )
18
18
  from ngio.io_pipes._io_pipes_roi import roi_to_slicing_dict
19
- from ngio.io_pipes._io_pipes_utils import SlicingInputType
20
19
  from ngio.io_pipes._match_shape import dask_match_shape, numpy_match_shape
20
+ from ngio.io_pipes._ops_slices import SlicingInputType
21
21
  from ngio.io_pipes._ops_transforms import TransformProtocol
22
22
  from ngio.io_pipes._zoom_transform import BaseZoomTransform
23
23
 
@@ -112,6 +112,7 @@ def _setup_numpy_getters(
112
112
  class NumpyGetterMasked(DataGetter[np.ndarray]):
113
113
  def __init__(
114
114
  self,
115
+ *,
115
116
  zarr_array: zarr.Array,
116
117
  dimensions: Dimensions,
117
118
  label_zarr_array: zarr.Array,
@@ -152,6 +153,7 @@ class NumpyGetterMasked(DataGetter[np.ndarray]):
152
153
  slicing_ops=self._data_getter.slicing_ops,
153
154
  axes_ops=self._data_getter.axes_ops,
154
155
  transforms=self._data_getter.transforms,
156
+ roi=roi,
155
157
  )
156
158
 
157
159
  @property
@@ -167,8 +169,8 @@ class NumpyGetterMasked(DataGetter[np.ndarray]):
167
169
  label_data=label_data,
168
170
  label=self.label_id,
169
171
  data_shape=data.shape,
170
- label_axes=self._label_data_getter.axes_ops.in_memory_axes,
171
- data_axes=self._data_getter.axes_ops.in_memory_axes,
172
+ label_axes=self._label_data_getter.axes_ops.output_axes,
173
+ data_axes=self._data_getter.axes_ops.output_axes,
172
174
  allow_rescaling=self._allow_rescaling,
173
175
  )
174
176
  if bool_mask.shape != data.shape:
@@ -180,6 +182,7 @@ class NumpyGetterMasked(DataGetter[np.ndarray]):
180
182
  class NumpySetterMasked(DataSetter[np.ndarray]):
181
183
  def __init__(
182
184
  self,
185
+ *,
183
186
  zarr_array: zarr.Array,
184
187
  dimensions: Dimensions,
185
188
  label_zarr_array: zarr.Array,
@@ -226,6 +229,7 @@ class NumpySetterMasked(DataSetter[np.ndarray]):
226
229
  slicing_ops=self._data_setter.slicing_ops,
227
230
  axes_ops=self._data_setter.axes_ops,
228
231
  transforms=self._data_setter.transforms,
232
+ roi=roi,
229
233
  )
230
234
 
231
235
  @property
@@ -235,14 +239,13 @@ class NumpySetterMasked(DataSetter[np.ndarray]):
235
239
  def set(self, patch: np.ndarray) -> None:
236
240
  data = self._data_getter()
237
241
  label_data = self._label_data_getter()
238
- print(data.shape, label_data.shape)
239
242
 
240
243
  bool_mask = _numpy_label_to_bool_mask(
241
244
  label_data=label_data,
242
245
  label=self.label_id,
243
246
  data_shape=data.shape,
244
- label_axes=self._label_data_getter.axes_ops.in_memory_axes,
245
- data_axes=self._data_getter.axes_ops.in_memory_axes,
247
+ label_axes=self._label_data_getter.axes_ops.output_axes,
248
+ data_axes=self._data_getter.axes_ops.output_axes,
246
249
  allow_rescaling=self._allow_rescaling,
247
250
  )
248
251
  if bool_mask.shape != data.shape:
@@ -342,6 +345,7 @@ def _setup_dask_getters(
342
345
  class DaskGetterMasked(DataGetter[DaskArray]):
343
346
  def __init__(
344
347
  self,
348
+ *,
345
349
  zarr_array: zarr.Array,
346
350
  dimensions: Dimensions,
347
351
  label_zarr_array: zarr.Array,
@@ -381,6 +385,7 @@ class DaskGetterMasked(DataGetter[DaskArray]):
381
385
  slicing_ops=self._data_getter.slicing_ops,
382
386
  axes_ops=self._data_getter.axes_ops,
383
387
  transforms=self._data_getter.transforms,
388
+ roi=roi,
384
389
  )
385
390
 
386
391
  @property
@@ -395,8 +400,8 @@ class DaskGetterMasked(DataGetter[DaskArray]):
395
400
  label_data=label_data,
396
401
  label=self.label_id,
397
402
  data_shape=data_shape,
398
- label_axes=self._label_data_getter.axes_ops.in_memory_axes,
399
- data_axes=self._data_getter.axes_ops.in_memory_axes,
403
+ label_axes=self._label_data_getter.axes_ops.output_axes,
404
+ data_axes=self._data_getter.axes_ops.output_axes,
400
405
  allow_rescaling=self._allow_rescaling,
401
406
  )
402
407
  if bool_mask.shape != data.shape:
@@ -408,6 +413,7 @@ class DaskGetterMasked(DataGetter[DaskArray]):
408
413
  class DaskSetterMasked(DataSetter[DaskArray]):
409
414
  def __init__(
410
415
  self,
416
+ *,
411
417
  zarr_array: zarr.Array,
412
418
  dimensions: Dimensions,
413
419
  label_zarr_array: zarr.Array,
@@ -456,6 +462,7 @@ class DaskSetterMasked(DataSetter[DaskArray]):
456
462
  slicing_ops=self._data_setter.slicing_ops,
457
463
  axes_ops=self._data_setter.axes_ops,
458
464
  transforms=self._data_setter.transforms,
465
+ roi=roi,
459
466
  )
460
467
 
461
468
  @property
@@ -471,8 +478,8 @@ class DaskSetterMasked(DataSetter[DaskArray]):
471
478
  label_data=label_data,
472
479
  label=self.label_id,
473
480
  data_shape=data_shape,
474
- label_axes=self._label_data_getter.axes_ops.in_memory_axes,
475
- data_axes=self._data_getter.axes_ops.in_memory_axes,
481
+ label_axes=self._label_data_getter.axes_ops.output_axes,
482
+ data_axes=self._data_getter.axes_ops.output_axes,
476
483
  allow_rescaling=self._allow_rescaling,
477
484
  )
478
485
  if bool_mask.shape != data.shape:
@@ -10,13 +10,14 @@ from ngio.io_pipes._io_pipes import (
10
10
  NumpyGetter,
11
11
  NumpySetter,
12
12
  )
13
- from ngio.io_pipes._io_pipes_utils import SlicingInputType
13
+ from ngio.io_pipes._ops_slices import SlicingInputType
14
14
  from ngio.io_pipes._ops_transforms import TransformProtocol
15
15
  from ngio.ome_zarr_meta.ngio_specs._pixel_size import PixelSize
16
16
  from ngio.utils import NgioValueError
17
17
 
18
18
 
19
19
  def roi_to_slicing_dict(
20
+ *,
20
21
  roi: Roi | RoiPixels,
21
22
  pixel_size: PixelSize | None = None,
22
23
  slicing_dict: dict[str, SlicingInputType] | None = None,
@@ -42,6 +43,7 @@ def roi_to_slicing_dict(
42
43
  class NumpyRoiGetter(NumpyGetter):
43
44
  def __init__(
44
45
  self,
46
+ *,
45
47
  zarr_array: zarr.Array,
46
48
  dimensions: Dimensions,
47
49
  roi: Roi | RoiPixels,
@@ -62,12 +64,14 @@ class NumpyRoiGetter(NumpyGetter):
62
64
  transforms=transforms,
63
65
  slicing_dict=input_slice_kwargs,
64
66
  remove_channel_selection=remove_channel_selection,
67
+ roi=roi,
65
68
  )
66
69
 
67
70
 
68
71
  class DaskRoiGetter(DaskGetter):
69
72
  def __init__(
70
73
  self,
74
+ *,
71
75
  zarr_array: zarr.Array,
72
76
  dimensions: Dimensions,
73
77
  roi: Roi | RoiPixels,
@@ -88,12 +92,14 @@ class DaskRoiGetter(DaskGetter):
88
92
  transforms=transforms,
89
93
  slicing_dict=input_slice_kwargs,
90
94
  remove_channel_selection=remove_channel_selection,
95
+ roi=roi,
91
96
  )
92
97
 
93
98
 
94
99
  class NumpyRoiSetter(NumpySetter):
95
100
  def __init__(
96
101
  self,
102
+ *,
97
103
  zarr_array: zarr.Array,
98
104
  dimensions: Dimensions,
99
105
  roi: Roi | RoiPixels,
@@ -114,12 +120,14 @@ class NumpyRoiSetter(NumpySetter):
114
120
  transforms=transforms,
115
121
  slicing_dict=input_slice_kwargs,
116
122
  remove_channel_selection=remove_channel_selection,
123
+ roi=roi,
117
124
  )
118
125
 
119
126
 
120
127
  class DaskRoiSetter(DaskSetter):
121
128
  def __init__(
122
129
  self,
130
+ *,
123
131
  zarr_array: zarr.Array,
124
132
  dimensions: Dimensions,
125
133
  roi: Roi | RoiPixels,
@@ -140,4 +148,5 @@ class DaskRoiSetter(DaskSetter):
140
148
  transforms=transforms,
141
149
  slicing_dict=input_slice_kwargs,
142
150
  remove_channel_selection=remove_channel_selection,
151
+ roi=roi,
143
152
  )
@@ -0,0 +1,56 @@
1
+ from collections.abc import Sequence
2
+ from typing import Protocol, TypeVar
3
+
4
+ import zarr
5
+
6
+ from ngio.common._roi import Roi, RoiPixels
7
+ from ngio.io_pipes._ops_axes import AxesOps
8
+ from ngio.io_pipes._ops_slices import SlicingOps
9
+ from ngio.io_pipes._ops_transforms import TransformProtocol
10
+
11
# Type produced by a getter; covariant because it only appears as a return type.
GetterDataType = TypeVar("GetterDataType", covariant=True)
# Type consumed by a setter; contravariant because it only appears as an argument.
SetterDataType = TypeVar("SetterDataType", contravariant=True)
13
+
14
+
15
class DataGetterProtocol(Protocol[GetterDataType]):
    """Structural interface for objects that read data.

    An implementation exposes the zarr array it reads from, the slicing,
    axes and transform operations it applies, and its ``roi``. Calling the
    object is equivalent to calling ``get()``; classes that explicitly
    subclass this protocol inherit that ``__call__`` default.
    """

    def get(self) -> GetterDataType: ...

    def __call__(self) -> GetterDataType:
        # Thin alias for ``get``.
        return self.get()

    @property
    def zarr_array(self) -> zarr.Array: ...

    @property
    def slicing_ops(self) -> SlicingOps: ...

    @property
    def axes_ops(self) -> AxesOps: ...

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None: ...

    @property
    def roi(self) -> Roi | RoiPixels: ...
35
+
36
+
37
class DataSetterProtocol(Protocol[SetterDataType]):
    """Structural interface for objects that write a patch of data.

    An implementation exposes the zarr array it writes to, the slicing,
    axes and transform operations it applies, and its ``roi``. Calling the
    object with a patch is equivalent to calling ``set(patch)``; classes
    that explicitly subclass this protocol inherit that ``__call__`` default.
    """

    def set(self, patch: SetterDataType) -> None: ...

    def __call__(self, patch: SetterDataType) -> None:
        # Thin alias for ``set``.
        return self.set(patch)

    @property
    def zarr_array(self) -> zarr.Array: ...

    @property
    def slicing_ops(self) -> SlicingOps: ...

    @property
    def axes_ops(self) -> AxesOps: ...

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None: ...

    @property
    def roi(self) -> Roi | RoiPixels: ...
@@ -3,8 +3,80 @@ from typing import TypeVar
3
3
 
4
4
  import dask.array as da
5
5
  import numpy as np
6
+ from pydantic import BaseModel, ConfigDict
6
7
 
7
- from ngio.ome_zarr_meta.ngio_specs._axes import AxesOps
8
+ from ngio.common._dimensions import Dimensions
9
+ from ngio.utils import NgioValueError
10
+
11
+ ##############################################################
12
+ #
13
+ # "AxesOps" Model
14
+ #
15
+ ##############################################################
16
+
17
+
18
class AxesOps(BaseModel):
    """Model to represent axes operations.

    The ``get_*`` properties describe how to turn data laid out as
    ``input_axes`` into data laid out as ``output_axes`` (e.g. from on-disk
    axes to in-memory axes). The ``set_*`` properties describe the inverse
    mapping, used to bring a patch back to the ``input_axes`` layout.
    A field left as ``None`` means the corresponding operation is a no-op.
    """

    # Axes names of the data before the operations are applied.
    input_axes: tuple[str, ...]
    # Axes names of the data after the operations are applied.
    output_axes: tuple[str, ...]
    # Axis permutation applied on read (None: identity).
    transpose_op: tuple[int, ...] | None = None
    # Positions at which new axes are inserted on read (None: nothing to insert).
    expand_op: tuple[int, ...] | None = None
    # Positions of axes removed on read (None: nothing to remove).
    squeeze_op: tuple[int, ...] | None = None
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    @property
    def is_no_op(self) -> bool:
        """Return True when no transpose, expand or squeeze is required."""
        return (
            self.transpose_op is None
            and self.expand_op is None
            and self.squeeze_op is None
        )

    @property
    def get_transpose_op(self) -> tuple[int, ...] | None:
        """Get the transpose axes."""
        return self.transpose_op

    @property
    def get_expand_op(self) -> tuple[int, ...] | None:
        """Get the expand axes."""
        return self.expand_op

    @property
    def get_squeeze_op(self) -> tuple[int, ...] | None:
        """Get the squeeze axes."""
        return self.squeeze_op

    @property
    def set_transpose_op(self) -> tuple[int, ...] | None:
        """Return the inverse permutation of ``transpose_op`` (or None).

        ``np.argsort`` yields NumPy integer scalars; they are converted to
        plain Python ints so the result matches the annotated return type
        (and stays JSON-serialisable / readable in reprs).
        """
        if self.transpose_op is None:
            return None
        return tuple(int(i) for i in np.argsort(self.transpose_op))

    @property
    def set_expand_op(self) -> tuple[int, ...] | None:
        """Axes to re-insert on write: the ones squeezed away on read."""
        return self.squeeze_op

    @property
    def set_squeeze_op(self) -> tuple[int, ...] | None:
        """Axes to drop on write: the ones inserted on read."""
        return self.expand_op
73
+
74
+
75
+ ##############################################################
76
+ #
77
+ # Axes Operations implementations
78
+ #
79
+ ##############################################################
8
80
 
9
81
 
10
82
  def _apply_numpy_axes_ops(
@@ -144,3 +216,129 @@ def set_as_sequence_axes_ops(
144
216
  transpose_axes=axes_ops.set_transpose_op,
145
217
  expand_axes=axes_ops.set_expand_op,
146
218
  )
219
+
220
+
221
+ ##############################################################
222
+ #
223
+ # Builder functions
224
+ #
225
+ ##############################################################
226
+
227
+
228
+ def _check_output_axes(axes: Sequence[str]) -> None:
229
+ """Check that the input axes are valid."""
230
+ unique_names = set(axes)
231
+ if len(unique_names) != len(axes):
232
+ raise NgioValueError(
233
+ "Duplicate axis names found. Please provide unique names for each axis."
234
+ )
235
+ for name in axes:
236
+ if not isinstance(name, str):
237
+ raise NgioValueError(
238
+ f"Invalid axis name '{name}'. Axis names must be strings."
239
+ )
240
+
241
+
242
+ def _build_squeeze_tuple(
243
+ input_axes: tuple[str, ...], output_axes: tuple[str, ...]
244
+ ) -> tuple[tuple[int, ...], tuple[str, ...]]:
245
+ """Build a tuple of axes to squeeze."""
246
+ axes_to_squeeze = []
247
+ axes_after_squeeze = []
248
+ for i, ax in enumerate(input_axes):
249
+ if ax not in output_axes:
250
+ axes_to_squeeze.append(i)
251
+ else:
252
+ axes_after_squeeze.append(ax)
253
+ return tuple(axes_to_squeeze), tuple(axes_after_squeeze)
254
+
255
+
256
+ def _build_transpose_tuple(
257
+ input_axes: tuple[str, ...], output_axes: tuple[str, ...]
258
+ ) -> tuple[tuple[int, ...], tuple[str, ...]]:
259
+ """Build a tuple of axes to transpose."""
260
+ transposition_order = []
261
+ axes_names_after_transpose = []
262
+ for ax in output_axes:
263
+ if ax in input_axes:
264
+ transposition_order.append(input_axes.index(ax))
265
+ axes_names_after_transpose.append(ax)
266
+ return tuple(transposition_order), tuple(axes_names_after_transpose)
267
+
268
+
269
+ def _build_expand_tuple(
270
+ input_axes: tuple[str, ...], output_axes: tuple[str, ...]
271
+ ) -> tuple[int, ...]:
272
+ """Build a tuple of axes to expand."""
273
+ axes_to_expand = []
274
+ for i, ax in enumerate(output_axes):
275
+ if ax not in input_axes:
276
+ axes_to_expand.append(i)
277
+ return tuple(axes_to_expand)
278
+
279
+
280
def _build_axes_ops(
    input_axes: tuple[str, ...], output_axes: tuple[str, ...]
) -> AxesOps:
    """Build the operations that map ``input_axes`` onto ``output_axes``.

    The operations are derived in the order squeeze -> transpose -> expand:
    input axes absent from the output are squeezed, the remaining axes are
    permuted into output order, and output axes absent from the input are
    expanded.

    Args:
        input_axes: Axes names of the source data.
        output_axes: Requested axes names, in the requested order.

    Returns:
        An ``AxesOps`` whose ``input_axes`` are the original ``input_axes``
        argument (not the intermediate axes) and whose empty or identity
        operations are set to ``None``.

    Raises:
        NgioValueError: If ``output_axes`` contains non-string or duplicate names.
    """
    _check_output_axes(output_axes)
    # Step 1: drop the input axes that are not requested in the output.
    axes_to_squeeze, squeezed_axes = _build_squeeze_tuple(input_axes, output_axes)
    # Step 2: permute the surviving axes into output order.
    transposition_order, transposed_axes = _build_transpose_tuple(
        squeezed_axes, output_axes
    )
    # Step 3: insert the output axes that the input does not have.
    axes_to_expand = _build_expand_tuple(transposed_axes, output_axes)

    # Empty/identity operations are normalised to None so ``is_no_op`` works.
    is_identity = transposition_order == tuple(range(len(transposition_order)))
    return AxesOps(
        input_axes=input_axes,
        output_axes=output_axes,
        transpose_op=None if is_identity else transposition_order,
        expand_op=axes_to_expand or None,
        squeeze_op=axes_to_squeeze or None,
    )
310
+
311
+
312
def _normalize_axes_order(
    dimensions: Dimensions,
    axes_order: Sequence[str],
) -> tuple[str, ...]:
    """Translate each requested axis name into its on-disk axis name.

    Every name that ``dimensions.axes_handler.get_axis`` resolves is replaced
    by the resolved axis' ``name``; names it cannot resolve are kept as given.
    This removes any ambiguity in the requested axes order.
    """

    def _on_disk_name(requested: str) -> str:
        found = dimensions.axes_handler.get_axis(requested)
        return requested if found is None else found.name

    return tuple(_on_disk_name(requested) for requested in axes_order)
328
+
329
+
330
def build_axes_ops(
    *,
    dimensions: Dimensions,
    input_axes: tuple[str, ...],
    axes_order: Sequence[str] | None,
) -> AxesOps:
    """Build the ``AxesOps`` that reorders ``input_axes`` into ``axes_order``.

    Args:
        dimensions: Used to resolve the requested names to on-disk axis names.
        input_axes: Axes names of the source data.
        axes_order: Requested output axes order, or ``None`` to keep
            ``input_axes`` unchanged.

    Returns:
        The axes operations; with ``axes_order=None`` all operations are
        left at their ``None`` defaults (identity mapping).
    """
    if axes_order is not None:
        requested_axes = _normalize_axes_order(
            dimensions=dimensions, axes_order=axes_order
        )
        return _build_axes_ops(input_axes=input_axes, output_axes=requested_axes)
    # No reordering requested: input and output layouts coincide.
    return AxesOps(input_axes=input_axes, output_axes=input_axes)