ngio 0.5.0b6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. ngio/__init__.py +69 -0
  2. ngio/common/__init__.py +28 -0
  3. ngio/common/_dimensions.py +335 -0
  4. ngio/common/_masking_roi.py +153 -0
  5. ngio/common/_pyramid.py +408 -0
  6. ngio/common/_roi.py +315 -0
  7. ngio/common/_synt_images_utils.py +101 -0
  8. ngio/common/_zoom.py +188 -0
  9. ngio/experimental/__init__.py +5 -0
  10. ngio/experimental/iterators/__init__.py +15 -0
  11. ngio/experimental/iterators/_abstract_iterator.py +390 -0
  12. ngio/experimental/iterators/_feature.py +189 -0
  13. ngio/experimental/iterators/_image_processing.py +130 -0
  14. ngio/experimental/iterators/_mappers.py +48 -0
  15. ngio/experimental/iterators/_rois_utils.py +126 -0
  16. ngio/experimental/iterators/_segmentation.py +235 -0
  17. ngio/hcs/__init__.py +19 -0
  18. ngio/hcs/_plate.py +1354 -0
  19. ngio/images/__init__.py +44 -0
  20. ngio/images/_abstract_image.py +967 -0
  21. ngio/images/_create_synt_container.py +132 -0
  22. ngio/images/_create_utils.py +423 -0
  23. ngio/images/_image.py +926 -0
  24. ngio/images/_label.py +411 -0
  25. ngio/images/_masked_image.py +531 -0
  26. ngio/images/_ome_zarr_container.py +1237 -0
  27. ngio/images/_table_ops.py +471 -0
  28. ngio/io_pipes/__init__.py +75 -0
  29. ngio/io_pipes/_io_pipes.py +361 -0
  30. ngio/io_pipes/_io_pipes_masked.py +488 -0
  31. ngio/io_pipes/_io_pipes_roi.py +146 -0
  32. ngio/io_pipes/_io_pipes_types.py +56 -0
  33. ngio/io_pipes/_match_shape.py +377 -0
  34. ngio/io_pipes/_ops_axes.py +344 -0
  35. ngio/io_pipes/_ops_slices.py +411 -0
  36. ngio/io_pipes/_ops_slices_utils.py +199 -0
  37. ngio/io_pipes/_ops_transforms.py +104 -0
  38. ngio/io_pipes/_zoom_transform.py +180 -0
  39. ngio/ome_zarr_meta/__init__.py +65 -0
  40. ngio/ome_zarr_meta/_meta_handlers.py +536 -0
  41. ngio/ome_zarr_meta/ngio_specs/__init__.py +77 -0
  42. ngio/ome_zarr_meta/ngio_specs/_axes.py +515 -0
  43. ngio/ome_zarr_meta/ngio_specs/_channels.py +462 -0
  44. ngio/ome_zarr_meta/ngio_specs/_dataset.py +89 -0
  45. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +539 -0
  46. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +438 -0
  47. ngio/ome_zarr_meta/ngio_specs/_pixel_size.py +122 -0
  48. ngio/ome_zarr_meta/v04/__init__.py +27 -0
  49. ngio/ome_zarr_meta/v04/_custom_models.py +18 -0
  50. ngio/ome_zarr_meta/v04/_v04_spec.py +473 -0
  51. ngio/ome_zarr_meta/v05/__init__.py +27 -0
  52. ngio/ome_zarr_meta/v05/_custom_models.py +18 -0
  53. ngio/ome_zarr_meta/v05/_v05_spec.py +511 -0
  54. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  55. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  56. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  57. ngio/resources/__init__.py +55 -0
  58. ngio/resources/resource_model.py +36 -0
  59. ngio/tables/__init__.py +43 -0
  60. ngio/tables/_abstract_table.py +270 -0
  61. ngio/tables/_tables_container.py +449 -0
  62. ngio/tables/backends/__init__.py +57 -0
  63. ngio/tables/backends/_abstract_backend.py +240 -0
  64. ngio/tables/backends/_anndata.py +139 -0
  65. ngio/tables/backends/_anndata_utils.py +90 -0
  66. ngio/tables/backends/_csv.py +19 -0
  67. ngio/tables/backends/_json.py +92 -0
  68. ngio/tables/backends/_parquet.py +19 -0
  69. ngio/tables/backends/_py_arrow_backends.py +222 -0
  70. ngio/tables/backends/_table_backends.py +226 -0
  71. ngio/tables/backends/_utils.py +608 -0
  72. ngio/tables/v1/__init__.py +23 -0
  73. ngio/tables/v1/_condition_table.py +71 -0
  74. ngio/tables/v1/_feature_table.py +125 -0
  75. ngio/tables/v1/_generic_table.py +49 -0
  76. ngio/tables/v1/_roi_table.py +575 -0
  77. ngio/transforms/__init__.py +5 -0
  78. ngio/transforms/_zoom.py +19 -0
  79. ngio/utils/__init__.py +45 -0
  80. ngio/utils/_cache.py +48 -0
  81. ngio/utils/_datasets.py +165 -0
  82. ngio/utils/_errors.py +37 -0
  83. ngio/utils/_fractal_fsspec_store.py +42 -0
  84. ngio/utils/_zarr_utils.py +534 -0
  85. ngio-0.5.0b6.dist-info/METADATA +148 -0
  86. ngio-0.5.0b6.dist-info/RECORD +88 -0
  87. ngio-0.5.0b6.dist-info/WHEEL +4 -0
  88. ngio-0.5.0b6.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,488 @@
1
+ from collections.abc import Sequence
2
+
3
+ import dask.array as da
4
+ import numpy as np
5
+ import zarr
6
+ from dask.array import Array as DaskArray
7
+
8
+ from ngio.common._dimensions import Dimensions
9
+ from ngio.common._roi import Roi
10
+ from ngio.io_pipes._io_pipes import (
11
+ DaskGetter,
12
+ DaskSetter,
13
+ DataGetter,
14
+ DataSetter,
15
+ NumpyGetter,
16
+ NumpySetter,
17
+ )
18
+ from ngio.io_pipes._io_pipes_roi import roi_to_slicing_dict
19
+ from ngio.io_pipes._match_shape import dask_match_shape, numpy_match_shape
20
+ from ngio.io_pipes._ops_slices import SlicingInputType
21
+ from ngio.io_pipes._ops_transforms import TransformProtocol
22
+ from ngio.io_pipes._zoom_transform import BaseZoomTransform
23
+
24
+ ##############################################################
25
+ #
26
+ # Numpy Pipes
27
+ #
28
+ ##############################################################
29
+
30
+
31
def _numpy_label_to_bool_mask(
    label_data: np.ndarray,
    label: int | None,
    data_shape: tuple[int, ...],
    label_axes: tuple[str, ...],
    data_axes: tuple[str, ...],
    allow_rescaling: bool = True,
) -> np.ndarray:
    """Turn a label array into a boolean mask matched to the data array.

    Args:
        label_data: Label array to convert.
        label: Label value to select. When ``None``, every non-zero element
            of ``label_data`` is selected instead.
        data_shape: Shape of the data array the mask is matched against.
        label_axes: Axes names of ``label_data``.
        data_axes: Axes names of the data array.
        allow_rescaling: Forwarded unchanged to ``numpy_match_shape``.

    Returns:
        The array returned by ``numpy_match_shape`` for the selection mask.
    """
    if label is None:
        # No specific label requested: select everything that is non-zero.
        selected = label_data != 0
    else:
        selected = label_data == label

    return numpy_match_shape(
        array=selected,
        reference_shape=data_shape,
        array_axes=label_axes,
        reference_axes=data_axes,
        allow_rescaling=allow_rescaling,
    )
53
+
54
+
55
def _setup_numpy_getters(
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    roi: Roi,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    allow_rescaling: bool = True,
    remove_channel_selection: bool = False,
) -> tuple[NumpyGetter, NumpyGetter, dict[str, SlicingInputType]]:
    """Create the numpy getters for the image data and for the label array.

    The ROI is converted to a slicing dictionary twice, once with the pixel
    size of ``dimensions`` and once with the pixel size of
    ``label_dimensions``.

    Args:
        zarr_array: Array holding the image data.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Array holding the label data.
        label_dimensions: Dimensions of ``label_zarr_array``.
        roi: Region of interest to read.
        axes_order: Axes order passed to both getters.
        transforms: Transforms passed to the image getter.
        label_transforms: Transforms passed to the label getter.
        slicing_dict: Extra slicing merged with the ROI slicing of the image.
        label_slicing_dict: Extra slicing merged with the ROI slicing of the
            label.
        allow_rescaling: When true, a nearest-neighbour ``BaseZoomTransform``
            is placed first in the label transform list.
        remove_channel_selection: Passed to the image getter only; the label
            getter is always built with ``remove_channel_selection=True``.

    Returns:
        A tuple ``(image_getter, label_getter, image_slicing_dict)``.
    """
    image_slicing = roi_to_slicing_dict(
        roi=roi,
        pixel_size=dimensions.pixel_size,
        slicing_dict=slicing_dict,
    )
    image_getter = NumpyGetter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=image_slicing,
        remove_channel_selection=remove_channel_selection,
    )

    if allow_rescaling:
        zoom = BaseZoomTransform(
            input_dimensions=dimensions,
            target_dimensions=label_dimensions,
            order="nearest",
        )
        # The zoom goes first; any caller-provided transforms follow it.
        extra_transforms = [] if label_transforms is None else list(label_transforms)
        label_transforms = [zoom, *extra_transforms]

    mask_slicing = roi_to_slicing_dict(
        roi=roi,
        pixel_size=label_dimensions.pixel_size,
        slicing_dict=label_slicing_dict,
    )
    mask_getter = NumpyGetter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_transforms,
        slicing_dict=mask_slicing,
        remove_channel_selection=True,
    )
    return image_getter, mask_getter, image_slicing
110
+
111
+
112
class NumpyGetterMasked(DataGetter[np.ndarray]):
    """Numpy getter that replaces data outside a label mask with a fill value."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        fill_value: int | float = 0,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Build the image and label getters used by ``get``.

        Args:
            zarr_array: Array holding the image data.
            dimensions: Dimensions of ``zarr_array``.
            label_zarr_array: Array holding the label data.
            label_dimensions: Dimensions of ``label_zarr_array``.
            roi: Region of interest; ``roi.label`` selects the label value.
            axes_order: Axes order passed to both getters.
            transforms: Transforms for the image getter.
            label_transforms: Transforms for the label getter.
            slicing_dict: Extra slicing for the image getter.
            label_slicing_dict: Extra slicing for the label getter.
            fill_value: Value used where the mask is false.
            allow_rescaling: Passed to the getter setup and to the mask
                shape matching.
            remove_channel_selection: Passed to the image getter.
        """
        # The slicing dictionary returned by the setup is not needed here.
        image_getter, mask_getter, _ = _setup_numpy_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._data_getter = image_getter
        self._label_data_getter = mask_getter

        self._label_id = roi.label
        self._fill_value = fill_value
        self._allow_rescaling = allow_rescaling

        # Expose the image getter's operations as this getter's own.
        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=image_getter.slicing_ops,
            axes_ops=image_getter.axes_ops,
            transforms=image_getter.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value taken from ``roi.label`` at construction."""
        return self._label_id

    def get(self) -> np.ndarray:
        """Read the image data and fill everything outside the mask.

        Returns:
            An array with the image values where the mask is true and
            ``fill_value`` elsewhere.
        """
        image = self._data_getter()
        labels = self._label_data_getter()

        mask = _numpy_label_to_bool_mask(
            label_data=labels,
            label=self.label_id,
            data_shape=image.shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != image.shape:
            mask = np.broadcast_to(mask, image.shape)
        return np.where(mask, image, self._fill_value)
180
+
181
+
182
class NumpySetterMasked(DataSetter[np.ndarray]):
    """Numpy setter that writes a patch only where a label mask is true."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Build the getters and the setter used by ``set``.

        Args:
            zarr_array: Array holding the image data to update.
            dimensions: Dimensions of ``zarr_array``.
            label_zarr_array: Array holding the label data.
            label_dimensions: Dimensions of ``label_zarr_array``.
            roi: Region of interest; ``roi.label`` selects the label value.
            axes_order: Axes order passed to the getters and the setter.
            transforms: Transforms for the image getter and the setter.
            label_transforms: Transforms for the label getter.
            slicing_dict: Extra slicing for the image getter and the setter.
            label_slicing_dict: Extra slicing for the label getter.
            allow_rescaling: Passed to the getter setup and to the mask
                shape matching.
            remove_channel_selection: Passed to the image getter and the
                setter.
        """
        image_getter, mask_getter, roi_slicing = _setup_numpy_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._data_getter = image_getter
        self._label_data_getter = mask_getter
        self._label_id = roi.label
        self._allow_rescaling = allow_rescaling

        # The setter reuses the ROI-derived slicing computed for the getter.
        writer = NumpySetter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_slicing,
            remove_channel_selection=remove_channel_selection,
        )
        self._data_setter = writer
        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=writer.slicing_ops,
            axes_ops=writer.axes_ops,
            transforms=writer.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value taken from ``roi.label`` at construction."""
        return self._label_id

    def set(self, patch: np.ndarray) -> None:
        """Write ``patch`` where the mask is true, keeping current data elsewhere.

        Args:
            patch: New values for the region.
        """
        current = self._data_getter()
        labels = self._label_data_getter()

        mask = _numpy_label_to_bool_mask(
            label_data=labels,
            label=self.label_id,
            data_shape=current.shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != current.shape:
            mask = np.broadcast_to(mask, current.shape)
        # Outside the mask the values just read are written back unchanged.
        self._data_setter(np.where(mask, patch, current))
255
+
256
+
257
+ ##############################################################
258
+ #
259
+ # Dask Pipes
260
+ #
261
+ ##############################################################
262
+
263
+
264
def _dask_label_to_bool_mask(
    label_data: DaskArray,
    label: int | None,
    data_shape: tuple[int, ...],
    label_axes: tuple[str, ...],
    data_axes: tuple[str, ...],
    allow_rescaling: bool = True,
) -> DaskArray:
    """Turn a dask label array into a boolean mask matched to the data array.

    Args:
        label_data: Label array to convert.
        label: Label value to select. When ``None``, every non-zero element
            of ``label_data`` is selected instead.
        data_shape: Shape of the data array the mask is matched against.
        label_axes: Axes names of ``label_data``.
        data_axes: Axes names of the data array.
        allow_rescaling: Forwarded unchanged to ``dask_match_shape``.

    Returns:
        The array returned by ``dask_match_shape`` for the selection mask.
    """
    if label is None:
        # No specific label requested: select everything that is non-zero.
        selected = label_data != 0
    else:
        selected = label_data == label

    return dask_match_shape(
        array=selected,
        reference_shape=data_shape,
        array_axes=label_axes,
        reference_axes=data_axes,
        allow_rescaling=allow_rescaling,
    )
286
+
287
+
288
def _setup_dask_getters(
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    roi: Roi,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    allow_rescaling: bool = True,
    remove_channel_selection: bool = False,
) -> tuple[DaskGetter, DaskGetter, dict[str, SlicingInputType]]:
    """Create the dask getters for the image data and for the label array.

    The ROI is converted to a slicing dictionary twice, once with the pixel
    size of ``dimensions`` and once with the pixel size of
    ``label_dimensions``.

    Args:
        zarr_array: Array holding the image data.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Array holding the label data.
        label_dimensions: Dimensions of ``label_zarr_array``.
        roi: Region of interest to read.
        axes_order: Axes order passed to both getters.
        transforms: Transforms passed to the image getter.
        label_transforms: Transforms passed to the label getter.
        slicing_dict: Extra slicing merged with the ROI slicing of the image.
        label_slicing_dict: Extra slicing merged with the ROI slicing of the
            label.
        allow_rescaling: When true, a nearest-neighbour ``BaseZoomTransform``
            is placed first in the label transform list.
        remove_channel_selection: Passed to the image getter only; the label
            getter is always built with ``remove_channel_selection=True``.

    Returns:
        A tuple ``(image_getter, label_getter, image_slicing_dict)``.
    """
    image_slicing = roi_to_slicing_dict(
        roi=roi,
        pixel_size=dimensions.pixel_size,
        slicing_dict=slicing_dict,
    )
    image_getter = DaskGetter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=image_slicing,
        remove_channel_selection=remove_channel_selection,
    )

    if allow_rescaling:
        zoom = BaseZoomTransform(
            input_dimensions=dimensions,
            target_dimensions=label_dimensions,
            order="nearest",
        )
        # The zoom goes first; any caller-provided transforms follow it.
        extra_transforms = [] if label_transforms is None else list(label_transforms)
        label_transforms = [zoom, *extra_transforms]

    mask_slicing = roi_to_slicing_dict(
        roi=roi,
        pixel_size=label_dimensions.pixel_size,
        slicing_dict=label_slicing_dict,
    )
    mask_getter = DaskGetter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_transforms,
        slicing_dict=mask_slicing,
        remove_channel_selection=True,
    )
    return image_getter, mask_getter, image_slicing
343
+
344
+
345
class DaskGetterMasked(DataGetter[DaskArray]):
    """Dask getter that replaces data outside a label mask with a fill value."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        fill_value: int | float = 0,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Build the image and label getters used by ``get``.

        Args:
            zarr_array: Array holding the image data.
            dimensions: Dimensions of ``zarr_array``.
            label_zarr_array: Array holding the label data.
            label_dimensions: Dimensions of ``label_zarr_array``.
            roi: Region of interest; ``roi.label`` selects the label value.
            axes_order: Axes order passed to both getters.
            transforms: Transforms for the image getter.
            label_transforms: Transforms for the label getter.
            slicing_dict: Extra slicing for the image getter.
            label_slicing_dict: Extra slicing for the label getter.
            fill_value: Value used where the mask is false.
            allow_rescaling: Passed to the getter setup and to the mask
                shape matching.
            remove_channel_selection: Passed to the image getter.
        """
        # The slicing dictionary returned by the setup is not needed here.
        image_getter, mask_getter, _ = _setup_dask_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._data_getter = image_getter
        self._label_data_getter = mask_getter
        self._label_id = roi.label
        self._fill_value = fill_value
        self._allow_rescaling = allow_rescaling

        # Expose the image getter's operations as this getter's own.
        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=image_getter.slicing_ops,
            axes_ops=image_getter.axes_ops,
            transforms=image_getter.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value taken from ``roi.label`` at construction."""
        return self._label_id

    def get(self) -> DaskArray:
        """Read the image data and fill everything outside the mask.

        Returns:
            A dask array with the image values where the mask is true and
            ``fill_value`` elsewhere.
        """
        image = self._data_getter()
        labels = self._label_data_getter()
        # Convert every shape entry to a plain int before shape matching.
        image_shape = tuple(int(size) for size in image.shape)

        mask = _dask_label_to_bool_mask(
            label_data=labels,
            label=self.label_id,
            data_shape=image_shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != image.shape:
            mask = da.broadcast_to(mask, image.shape)
        return da.where(mask, image, self._fill_value)
411
+
412
+
413
class DaskSetterMasked(DataSetter[DaskArray]):
    """Dask setter that writes a patch only where a label mask is true."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Build the getters and the setter used by ``set``.

        Args:
            zarr_array: Array holding the image data to update.
            dimensions: Dimensions of ``zarr_array``.
            label_zarr_array: Array holding the label data.
            label_dimensions: Dimensions of ``label_zarr_array``.
            roi: Region of interest; ``roi.label`` selects the label value.
            axes_order: Axes order passed to the getters and the setter.
            transforms: Transforms for the image getter and the setter.
            label_transforms: Transforms for the label getter.
            slicing_dict: Extra slicing for the image getter and the setter.
            label_slicing_dict: Extra slicing for the label getter.
            allow_rescaling: Passed to the getter setup and to the mask
                shape matching.
            remove_channel_selection: Passed to the image getter and the
                setter.
        """
        image_getter, mask_getter, roi_slicing = _setup_dask_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._data_getter = image_getter
        self._label_data_getter = mask_getter

        self._label_id = roi.label
        self._allow_rescaling = allow_rescaling

        # The setter reuses the ROI-derived slicing computed for the getter.
        writer = DaskSetter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_slicing,
            remove_channel_selection=remove_channel_selection,
        )
        self._data_setter = writer

        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=writer.slicing_ops,
            axes_ops=writer.axes_ops,
            transforms=writer.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value taken from ``roi.label`` at construction."""
        return self._label_id

    def set(self, patch: DaskArray) -> None:
        """Write ``patch`` where the mask is true, keeping current data elsewhere.

        Args:
            patch: New values for the region.
        """
        current = self._data_getter()
        labels = self._label_data_getter()
        # Convert every shape entry to a plain int before shape matching.
        current_shape = tuple(int(size) for size in current.shape)

        mask = _dask_label_to_bool_mask(
            label_data=labels,
            label=self.label_id,
            data_shape=current_shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != current.shape:
            mask = da.broadcast_to(mask, current.shape)
        # Outside the mask the values just read are written back unchanged.
        self._data_setter(da.where(mask, patch, current))
@@ -0,0 +1,146 @@
1
+ from collections.abc import Sequence
2
+
3
+ import zarr
4
+
5
+ from ngio.common._dimensions import Dimensions
6
+ from ngio.common._roi import Roi
7
+ from ngio.io_pipes._io_pipes import (
8
+ DaskGetter,
9
+ DaskSetter,
10
+ NumpyGetter,
11
+ NumpySetter,
12
+ )
13
+ from ngio.io_pipes._ops_slices import SlicingInputType
14
+ from ngio.io_pipes._ops_transforms import TransformProtocol
15
+ from ngio.ome_zarr_meta.ngio_specs._pixel_size import PixelSize
16
+
17
+
18
+ def roi_to_slicing_dict(
19
+ *,
20
+ roi: Roi,
21
+ pixel_size: PixelSize,
22
+ slicing_dict: dict[str, SlicingInputType] | None = None,
23
+ ) -> dict[str, SlicingInputType]:
24
+ """Convert a ROI to a slicing dictionary."""
25
+ roi_slicing_dict: dict[str, SlicingInputType] = roi.to_slicing_dict(
26
+ pixel_size=pixel_size
27
+ ) # type: ignore
28
+ if slicing_dict is None:
29
+ return roi_slicing_dict
30
+
31
+ # Additional slice kwargs can be provided
32
+ # and will override the ones from the ROI
33
+ roi_slicing_dict.update(slicing_dict)
34
+ return roi_slicing_dict
35
+
36
+
37
class NumpyRoiGetter(NumpyGetter):
    """``NumpyGetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` to slicing and initialise the base getter.

        Args:
            zarr_array: Array to read from.
            dimensions: Dimensions of ``zarr_array``; its pixel size is used
                to convert the ROI.
            roi: Region of interest to read.
            axes_order: Axes order passed to the base getter.
            transforms: Transforms passed to the base getter.
            slicing_dict: Extra slicing combined with the ROI slicing by
                ``roi_to_slicing_dict``.
            remove_channel_selection: Passed to the base getter.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
63
+
64
+
65
class DaskRoiGetter(DaskGetter):
    """``DaskGetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` to slicing and initialise the base getter.

        Args:
            zarr_array: Array to read from.
            dimensions: Dimensions of ``zarr_array``; its pixel size is used
                to convert the ROI.
            roi: Region of interest to read.
            axes_order: Axes order passed to the base getter.
            transforms: Transforms passed to the base getter.
            slicing_dict: Extra slicing combined with the ROI slicing by
                ``roi_to_slicing_dict``.
            remove_channel_selection: Passed to the base getter.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
91
+
92
+
93
class NumpyRoiSetter(NumpySetter):
    """``NumpySetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` to slicing and initialise the base setter.

        Args:
            zarr_array: Array to write to.
            dimensions: Dimensions of ``zarr_array``; its pixel size is used
                to convert the ROI.
            roi: Region of interest to write.
            axes_order: Axes order passed to the base setter.
            transforms: Transforms passed to the base setter.
            slicing_dict: Extra slicing combined with the ROI slicing by
                ``roi_to_slicing_dict``.
            remove_channel_selection: Passed to the base setter.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
119
+
120
+
121
class DaskRoiSetter(DaskSetter):
    """``DaskSetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` to slicing and initialise the base setter.

        Args:
            zarr_array: Array to write to.
            dimensions: Dimensions of ``zarr_array``; its pixel size is used
                to convert the ROI.
            roi: Region of interest to write.
            axes_order: Axes order passed to the base setter.
            transforms: Transforms passed to the base setter.
            slicing_dict: Extra slicing combined with the ROI slicing by
                ``roi_to_slicing_dict``.
            remove_channel_selection: Passed to the base setter.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
@@ -0,0 +1,56 @@
1
+ from collections.abc import Sequence
2
+ from typing import Protocol, TypeVar
3
+
4
+ import zarr
5
+
6
+ from ngio.common._roi import Roi
7
+ from ngio.io_pipes._ops_axes import AxesOps
8
+ from ngio.io_pipes._ops_slices import SlicingOps
9
+ from ngio.io_pipes._ops_transforms import TransformProtocol
10
+
11
# Type of the data a getter returns; declared covariant.
GetterDataType = TypeVar("GetterDataType", covariant=True)
# Type of the data a setter accepts; declared contravariant.
SetterDataType = TypeVar("SetterDataType", contravariant=True)
13
+
14
+
15
class DataGetterProtocol(Protocol[GetterDataType]):
    """Structural interface of an object that reads data.

    Implementations expose the array and operations they work with as
    read-only properties and return the data from ``get``. Calling the
    object is the same as calling ``get``.
    """

    @property
    def zarr_array(self) -> zarr.Array:
        """Zarr array associated with the getter."""
        ...

    @property
    def slicing_ops(self) -> SlicingOps:
        """Slicing operations associated with the getter."""
        ...

    @property
    def axes_ops(self) -> AxesOps:
        """Axes operations associated with the getter."""
        ...

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None:
        """Transforms associated with the getter, or ``None``."""
        ...

    @property
    def roi(self) -> Roi:
        """ROI associated with the getter."""
        ...

    def get(self) -> GetterDataType:
        """Return the data."""
        ...

    def __call__(self) -> GetterDataType:
        """Return the result of ``get``."""
        return self.get()
35
+
36
+
37
class DataSetterProtocol(Protocol[SetterDataType]):
    """Structural interface of an object that writes data.

    Implementations expose the array and operations they work with as
    read-only properties and accept the data in ``set``. Calling the object
    with a patch is the same as calling ``set`` with it.
    """

    @property
    def zarr_array(self) -> zarr.Array:
        """Zarr array associated with the setter."""
        ...

    @property
    def slicing_ops(self) -> SlicingOps:
        """Slicing operations associated with the setter."""
        ...

    @property
    def axes_ops(self) -> AxesOps:
        """Axes operations associated with the setter."""
        ...

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None:
        """Transforms associated with the setter, or ``None``."""
        ...

    @property
    def roi(self) -> Roi:
        """ROI associated with the setter."""
        ...

    def set(self, patch: SetterDataType) -> None:
        """Write ``patch``."""
        ...

    def __call__(self, patch: SetterDataType) -> None:
        """Forward ``patch`` to ``set`` and return its result."""
        return self.set(patch)