ngio 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ngio/__init__.py +7 -2
- ngio/common/__init__.py +5 -52
- ngio/common/_dimensions.py +270 -55
- ngio/common/_masking_roi.py +38 -10
- ngio/common/_pyramid.py +51 -30
- ngio/common/_roi.py +269 -82
- ngio/common/_synt_images_utils.py +101 -0
- ngio/common/_zoom.py +49 -19
- ngio/experimental/__init__.py +5 -0
- ngio/experimental/iterators/__init__.py +15 -0
- ngio/experimental/iterators/_abstract_iterator.py +390 -0
- ngio/experimental/iterators/_feature.py +189 -0
- ngio/experimental/iterators/_image_processing.py +130 -0
- ngio/experimental/iterators/_mappers.py +48 -0
- ngio/experimental/iterators/_rois_utils.py +127 -0
- ngio/experimental/iterators/_segmentation.py +235 -0
- ngio/hcs/_plate.py +41 -36
- ngio/images/__init__.py +22 -1
- ngio/images/_abstract_image.py +403 -176
- ngio/images/_create.py +31 -15
- ngio/images/_create_synt_container.py +138 -0
- ngio/images/_image.py +452 -63
- ngio/images/_label.py +56 -30
- ngio/images/_masked_image.py +387 -129
- ngio/images/_ome_zarr_container.py +237 -67
- ngio/{common → images}/_table_ops.py +41 -41
- ngio/io_pipes/__init__.py +75 -0
- ngio/io_pipes/_io_pipes.py +361 -0
- ngio/io_pipes/_io_pipes_masked.py +488 -0
- ngio/io_pipes/_io_pipes_roi.py +152 -0
- ngio/io_pipes/_io_pipes_types.py +56 -0
- ngio/io_pipes/_match_shape.py +376 -0
- ngio/io_pipes/_ops_axes.py +344 -0
- ngio/io_pipes/_ops_slices.py +446 -0
- ngio/io_pipes/_ops_slices_utils.py +196 -0
- ngio/io_pipes/_ops_transforms.py +104 -0
- ngio/io_pipes/_zoom_transform.py +175 -0
- ngio/ome_zarr_meta/__init__.py +4 -2
- ngio/ome_zarr_meta/ngio_specs/__init__.py +4 -10
- ngio/ome_zarr_meta/ngio_specs/_axes.py +186 -175
- ngio/ome_zarr_meta/ngio_specs/_channels.py +55 -18
- ngio/ome_zarr_meta/ngio_specs/_dataset.py +48 -122
- ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +6 -15
- ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +38 -87
- ngio/ome_zarr_meta/ngio_specs/_pixel_size.py +17 -1
- ngio/ome_zarr_meta/v04/_v04_spec_utils.py +34 -31
- ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
- ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
- ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
- ngio/resources/__init__.py +55 -0
- ngio/resources/resource_model.py +36 -0
- ngio/tables/backends/_abstract_backend.py +5 -6
- ngio/tables/backends/_anndata.py +1 -2
- ngio/tables/backends/_anndata_utils.py +3 -3
- ngio/tables/backends/_non_zarr_backends.py +1 -1
- ngio/tables/backends/_table_backends.py +0 -1
- ngio/tables/backends/_utils.py +3 -3
- ngio/tables/v1/_roi_table.py +165 -70
- ngio/transforms/__init__.py +5 -0
- ngio/transforms/_zoom.py +19 -0
- ngio/utils/__init__.py +2 -3
- ngio/utils/_datasets.py +5 -0
- ngio/utils/_logger.py +19 -0
- ngio/utils/_zarr_utils.py +6 -6
- {ngio-0.3.4.dist-info → ngio-0.4.0.dist-info}/METADATA +24 -22
- ngio-0.4.0.dist-info/RECORD +85 -0
- ngio/common/_array_pipe.py +0 -288
- ngio/common/_axes_transforms.py +0 -64
- ngio/common/_common_types.py +0 -5
- ngio/common/_slicer.py +0 -96
- ngio-0.3.4.dist-info/RECORD +0 -61
- {ngio-0.3.4.dist-info → ngio-0.4.0.dist-info}/WHEEL +0 -0
- {ngio-0.3.4.dist-info → ngio-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,488 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import dask.array as da
|
|
4
|
+
import numpy as np
|
|
5
|
+
import zarr
|
|
6
|
+
from dask.array import Array as DaskArray
|
|
7
|
+
|
|
8
|
+
from ngio.common._dimensions import Dimensions
|
|
9
|
+
from ngio.common._roi import Roi, RoiPixels
|
|
10
|
+
from ngio.io_pipes._io_pipes import (
|
|
11
|
+
DaskGetter,
|
|
12
|
+
DaskSetter,
|
|
13
|
+
DataGetter,
|
|
14
|
+
DataSetter,
|
|
15
|
+
NumpyGetter,
|
|
16
|
+
NumpySetter,
|
|
17
|
+
)
|
|
18
|
+
from ngio.io_pipes._io_pipes_roi import roi_to_slicing_dict
|
|
19
|
+
from ngio.io_pipes._match_shape import dask_match_shape, numpy_match_shape
|
|
20
|
+
from ngio.io_pipes._ops_slices import SlicingInputType
|
|
21
|
+
from ngio.io_pipes._ops_transforms import TransformProtocol
|
|
22
|
+
from ngio.io_pipes._zoom_transform import BaseZoomTransform
|
|
23
|
+
|
|
24
|
+
##############################################################
|
|
25
|
+
#
|
|
26
|
+
# Numpy Pipes
|
|
27
|
+
#
|
|
28
|
+
##############################################################
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _numpy_label_to_bool_mask(
    label_data: np.ndarray,
    label: int | None,
    data_shape: tuple[int, ...],
    label_axes: tuple[str, ...],
    data_axes: tuple[str, ...],
    allow_rescaling: bool = True,
) -> np.ndarray:
    """Turn a label image into a boolean mask aligned with a data array.

    Args:
        label_data: Label image read from the label array.
        label: Label value to select. When ``None`` every non-zero pixel
            is selected.
        data_shape: Shape of the data array the mask must match.
        label_axes: Axis names of ``label_data``.
        data_axes: Axis names of the data array.
        allow_rescaling: Forwarded unchanged to ``numpy_match_shape``.

    Returns:
        The boolean selection after ``numpy_match_shape`` has aligned it
        to ``data_shape`` / ``data_axes``.
    """
    selected = (label_data != 0) if label is None else (label_data == label)
    return numpy_match_shape(
        array=selected,
        reference_shape=data_shape,
        array_axes=label_axes,
        reference_axes=data_axes,
        allow_rescaling=allow_rescaling,
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _setup_numpy_getters(
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    roi: Roi | RoiPixels,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    allow_rescaling: bool = True,
    remove_channel_selection: bool = False,
) -> tuple[NumpyGetter, NumpyGetter, dict[str, SlicingInputType]]:
    """Build the data getter and the label getter used by masked numpy pipes.

    ``roi`` is turned into a slicing dictionary twice. The first call uses
    ``dimensions.pixel_size`` for the data array, and the second uses
    ``label_dimensions.pixel_size`` for the label array. A world-coordinate
    ``Roi`` is therefore converted with each array's own pixel size. Entries
    in ``slicing_dict`` / ``label_slicing_dict`` override the ROI-derived ones.

    Args:
        zarr_array: Data array.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimensions of ``label_zarr_array``.
        roi: Region to access.
        axes_order: Axes order passed to both getters.
        transforms: Transforms for the data getter.
        label_transforms: Transforms for the label getter.
        slicing_dict: Extra slicing for the data array.
        label_slicing_dict: Extra slicing for the label array.
        allow_rescaling: When true, a nearest-order ``BaseZoomTransform``
            is placed in front of ``label_transforms``.
        remove_channel_selection: Forwarded to the data getter only; the
            label getter always receives ``True``.

    Returns:
        ``(data_getter, label_getter, data_slicing_dict)``. The last item is
        the merged slicing dictionary used by the data getter.
    """
    data_slicing = roi_to_slicing_dict(
        roi=roi,
        pixel_size=dimensions.pixel_size,
        slicing_dict=slicing_dict,
    )
    data_getter = NumpyGetter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=data_slicing,
        remove_channel_selection=remove_channel_selection,
    )

    label_pipeline = label_transforms
    if allow_rescaling:
        # order="nearest" keeps label ids from being blended; the zoom always
        # runs before any caller-supplied label transforms.
        zoom = BaseZoomTransform(
            input_dimensions=dimensions,
            target_dimensions=label_dimensions,
            order="nearest",
        )
        label_pipeline = [zoom, *(label_transforms or ())]

    label_getter = NumpyGetter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_pipeline,
        slicing_dict=roi_to_slicing_dict(
            roi=roi,
            pixel_size=label_dimensions.pixel_size,
            slicing_dict=label_slicing_dict,
        ),
        remove_channel_selection=True,
    )
    return data_getter, label_getter, data_slicing
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class NumpyGetterMasked(DataGetter[np.ndarray]):
    """Numpy getter that reads a ROI and blanks pixels outside a label mask.

    Pixels whose label equals ``roi.label`` (or is non-zero when
    ``roi.label`` is ``None``) keep their value. Every other pixel is
    replaced with ``fill_value``.
    """

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        fill_value: int | float = 0,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Create the data and label getters for ``roi``."""
        # The merged slicing dict is not needed here: the data getter
        # already carries its slicing ops.
        (
            self._data_getter,
            self._label_data_getter,
            _,
        ) = _setup_numpy_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._label_id = roi.label
        self._fill_value = fill_value
        self._allow_rescaling = allow_rescaling

        data_getter = self._data_getter
        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=data_getter.slicing_ops,
            axes_ops=data_getter.axes_ops,
            transforms=data_getter.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value selected by the mask (``None`` means any non-zero)."""
        return self._label_id

    def get(self) -> np.ndarray:
        """Return the ROI data with pixels outside the mask set to fill_value."""
        data = self._data_getter()
        mask = _numpy_label_to_bool_mask(
            label_data=self._label_data_getter(),
            label=self.label_id,
            data_shape=data.shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != data.shape:
            mask = np.broadcast_to(mask, data.shape)
        return np.where(mask, data, self._fill_value)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class NumpySetterMasked(DataSetter[np.ndarray]):
    """Numpy setter that writes a patch only where a label mask is true.

    Pixels outside the mask are rewritten with the values currently stored,
    so only masked pixels change.
    """

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Create the getters used to build the mask and the data setter."""
        (
            self._data_getter,
            self._label_data_getter,
            merged_slicing,
        ) = _setup_numpy_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._label_id = roi.label
        self._allow_rescaling = allow_rescaling

        # The setter reuses the ROI-merged slicing of the data getter so the
        # read-modify-write touches exactly the same region.
        self._data_setter = NumpySetter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=merged_slicing,
            remove_channel_selection=remove_channel_selection,
        )
        data_setter = self._data_setter
        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=data_setter.slicing_ops,
            axes_ops=data_setter.axes_ops,
            transforms=data_setter.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value selected by the mask (``None`` means any non-zero)."""
        return self._label_id

    def set(self, patch: np.ndarray) -> None:
        """Write ``patch`` where the mask is true; keep current data elsewhere."""
        current = self._data_getter()
        mask = _numpy_label_to_bool_mask(
            label_data=self._label_data_getter(),
            label=self.label_id,
            data_shape=current.shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != current.shape:
            mask = np.broadcast_to(mask, current.shape)
        self._data_setter(np.where(mask, patch, current))
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
##############################################################
|
|
258
|
+
#
|
|
259
|
+
# Dask Pipes
|
|
260
|
+
#
|
|
261
|
+
##############################################################
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _dask_label_to_bool_mask(
    label_data: DaskArray,
    label: int | None,
    data_shape: tuple[int, ...],
    label_axes: tuple[str, ...],
    data_axes: tuple[str, ...],
    allow_rescaling: bool = True,
) -> DaskArray:
    """Lazily turn a label image into a boolean mask aligned with a data array.

    Args:
        label_data: Label image read from the label array (dask).
        label: Label value to select. When ``None`` every non-zero pixel
            is selected.
        data_shape: Shape of the data array the mask must match.
        label_axes: Axis names of ``label_data``.
        data_axes: Axis names of the data array.
        allow_rescaling: Forwarded unchanged to ``dask_match_shape``.

    Returns:
        The boolean selection after ``dask_match_shape`` has aligned it to
        ``data_shape`` / ``data_axes``.
    """
    if label is None:
        selected = label_data != 0
    else:
        selected = label_data == label
    return dask_match_shape(
        array=selected,
        reference_shape=data_shape,
        array_axes=label_axes,
        reference_axes=data_axes,
        allow_rescaling=allow_rescaling,
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _setup_dask_getters(
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    roi: Roi | RoiPixels,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    allow_rescaling: bool = True,
    remove_channel_selection: bool = False,
) -> tuple[DaskGetter, DaskGetter, dict[str, SlicingInputType]]:
    """Build the data getter and the label getter used by masked dask pipes.

    ``roi`` is turned into a slicing dictionary twice. The first call uses
    ``dimensions.pixel_size`` for the data array, and the second uses
    ``label_dimensions.pixel_size`` for the label array. A world-coordinate
    ``Roi`` is therefore converted with each array's own pixel size. Entries
    in ``slicing_dict`` / ``label_slicing_dict`` override the ROI-derived ones.

    Args:
        zarr_array: Data array.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimensions of ``label_zarr_array``.
        roi: Region to access.
        axes_order: Axes order passed to both getters.
        transforms: Transforms for the data getter.
        label_transforms: Transforms for the label getter.
        slicing_dict: Extra slicing for the data array.
        label_slicing_dict: Extra slicing for the label array.
        allow_rescaling: When true, a nearest-order ``BaseZoomTransform``
            is placed in front of ``label_transforms``.
        remove_channel_selection: Forwarded to the data getter only; the
            label getter always receives ``True``.

    Returns:
        ``(data_getter, label_getter, data_slicing_dict)``. The last item is
        the merged slicing dictionary used by the data getter.
    """
    data_slicing = roi_to_slicing_dict(
        roi=roi,
        pixel_size=dimensions.pixel_size,
        slicing_dict=slicing_dict,
    )
    data_getter = DaskGetter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=data_slicing,
        remove_channel_selection=remove_channel_selection,
    )

    label_pipeline = label_transforms
    if allow_rescaling:
        # order="nearest" keeps label ids from being blended; the zoom always
        # runs before any caller-supplied label transforms.
        zoom = BaseZoomTransform(
            input_dimensions=dimensions,
            target_dimensions=label_dimensions,
            order="nearest",
        )
        label_pipeline = [zoom, *(label_transforms or ())]

    label_getter = DaskGetter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_pipeline,
        slicing_dict=roi_to_slicing_dict(
            roi=roi,
            pixel_size=label_dimensions.pixel_size,
            slicing_dict=label_slicing_dict,
        ),
        remove_channel_selection=True,
    )
    return data_getter, label_getter, data_slicing
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class DaskGetterMasked(DataGetter[DaskArray]):
    """Dask getter that reads a ROI and blanks pixels outside a label mask.

    Pixels whose label equals ``roi.label`` (or is non-zero when
    ``roi.label`` is ``None``) keep their value. Every other pixel is
    replaced with ``fill_value``. The result is built with ``da.where``.
    """

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        fill_value: int | float = 0,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Create the data and label getters for ``roi``."""
        # The merged slicing dict is not needed here: the data getter
        # already carries its slicing ops.
        (
            self._data_getter,
            self._label_data_getter,
            _,
        ) = _setup_dask_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._label_id = roi.label
        self._fill_value = fill_value
        self._allow_rescaling = allow_rescaling

        data_getter = self._data_getter
        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=data_getter.slicing_ops,
            axes_ops=data_getter.axes_ops,
            transforms=data_getter.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value selected by the mask (``None`` means any non-zero)."""
        return self._label_id

    def get(self) -> DaskArray:
        """Return the ROI data with pixels outside the mask set to fill_value."""
        data = self._data_getter()
        labels = self._label_data_getter()
        # Plain ints for the shape-matching helper.
        target_shape = tuple(int(extent) for extent in data.shape)
        mask = _dask_label_to_bool_mask(
            label_data=labels,
            label=self.label_id,
            data_shape=target_shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != data.shape:
            mask = da.broadcast_to(mask, data.shape)
        return da.where(mask, data, self._fill_value)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class DaskSetterMasked(DataSetter[DaskArray]):
    """Dask setter that writes a patch only where a label mask is true.

    Pixels outside the mask are rewritten with the values currently stored,
    so only masked pixels change.
    """

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        label_zarr_array: zarr.Array,
        label_dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        label_slicing_dict: dict[str, SlicingInputType] | None = None,
        allow_rescaling: bool = True,
        remove_channel_selection: bool = False,
    ):
        """Create the getters used to build the mask and the data setter."""
        (
            self._data_getter,
            self._label_data_getter,
            merged_slicing,
        ) = _setup_dask_getters(
            zarr_array=zarr_array,
            dimensions=dimensions,
            label_zarr_array=label_zarr_array,
            label_dimensions=label_dimensions,
            roi=roi,
            axes_order=axes_order,
            transforms=transforms,
            label_transforms=label_transforms,
            slicing_dict=slicing_dict,
            label_slicing_dict=label_slicing_dict,
            allow_rescaling=allow_rescaling,
            remove_channel_selection=remove_channel_selection,
        )
        self._label_id = roi.label
        self._allow_rescaling = allow_rescaling

        # The setter reuses the ROI-merged slicing of the data getter so the
        # read-modify-write touches exactly the same region.
        self._data_setter = DaskSetter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=merged_slicing,
            remove_channel_selection=remove_channel_selection,
        )
        data_setter = self._data_setter
        super().__init__(
            zarr_array=zarr_array,
            slicing_ops=data_setter.slicing_ops,
            axes_ops=data_setter.axes_ops,
            transforms=data_setter.transforms,
            roi=roi,
        )

    @property
    def label_id(self) -> int | None:
        """Label value selected by the mask (``None`` means any non-zero)."""
        return self._label_id

    def set(self, patch: DaskArray) -> None:
        """Write ``patch`` where the mask is true; keep current data elsewhere."""
        current = self._data_getter()
        labels = self._label_data_getter()
        # Plain ints for the shape-matching helper.
        target_shape = tuple(int(extent) for extent in current.shape)
        mask = _dask_label_to_bool_mask(
            label_data=labels,
            label=self.label_id,
            data_shape=target_shape,
            label_axes=self._label_data_getter.axes_ops.output_axes,
            data_axes=self._data_getter.axes_ops.output_axes,
            allow_rescaling=self._allow_rescaling,
        )
        if mask.shape != current.shape:
            mask = da.broadcast_to(mask, current.shape)
        self._data_setter(da.where(mask, patch, current))
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
import zarr
|
|
4
|
+
|
|
5
|
+
from ngio.common._dimensions import Dimensions
|
|
6
|
+
from ngio.common._roi import Roi, RoiPixels
|
|
7
|
+
from ngio.io_pipes._io_pipes import (
|
|
8
|
+
DaskGetter,
|
|
9
|
+
DaskSetter,
|
|
10
|
+
NumpyGetter,
|
|
11
|
+
NumpySetter,
|
|
12
|
+
)
|
|
13
|
+
from ngio.io_pipes._ops_slices import SlicingInputType
|
|
14
|
+
from ngio.io_pipes._ops_transforms import TransformProtocol
|
|
15
|
+
from ngio.ome_zarr_meta.ngio_specs._pixel_size import PixelSize
|
|
16
|
+
from ngio.utils import NgioValueError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def roi_to_slicing_dict(
    *,
    roi: Roi | RoiPixels,
    pixel_size: PixelSize | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
) -> dict[str, SlicingInputType]:
    """Translate a ROI into a slicing dictionary.

    Args:
        roi: The region. A ``Roi`` instance is first converted to pixel
            coordinates with ``pixel_size``.
        pixel_size: Required when ``roi`` is a ``Roi``.
        slicing_dict: Optional extra entries. They are merged on top of the
            ROI-derived ones, so they win on key collisions.

    Returns:
        The dictionary produced by ``to_slicing_dict()`` on the pixel ROI,
        updated with ``slicing_dict`` when given.

    Raises:
        NgioValueError: If ``roi`` is a ``Roi`` and ``pixel_size`` is None.
    """
    if not isinstance(roi, Roi):
        pixel_roi = roi
    elif pixel_size is None:
        raise NgioValueError(
            "pixel_size must be provided when converting a Roi to slice_kwargs."
        )
    else:
        pixel_roi = roi.to_roi_pixels(pixel_size=pixel_size)

    merged: dict[str, SlicingInputType] = pixel_roi.to_slicing_dict()  # type: ignore
    if slicing_dict is not None:
        merged.update(slicing_dict)
    return merged
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class NumpyRoiGetter(NumpyGetter):
    """``NumpyGetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` with ``dimensions.pixel_size`` and delegate to the parent.

        Entries in ``slicing_dict`` override the ROI-derived slicing.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class DaskRoiGetter(DaskGetter):
    """``DaskGetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` with ``dimensions.pixel_size`` and delegate to the parent.

        Entries in ``slicing_dict`` override the ROI-derived slicing.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class NumpyRoiSetter(NumpySetter):
    """``NumpySetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` with ``dimensions.pixel_size`` and delegate to the parent.

        Entries in ``slicing_dict`` override the ROI-derived slicing.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class DaskRoiSetter(DaskSetter):
    """``DaskSetter`` whose slicing is derived from a ROI."""

    def __init__(
        self,
        *,
        zarr_array: zarr.Array,
        dimensions: Dimensions,
        roi: Roi | RoiPixels,
        axes_order: Sequence[str] | None = None,
        transforms: Sequence[TransformProtocol] | None = None,
        slicing_dict: dict[str, SlicingInputType] | None = None,
        remove_channel_selection: bool = False,
    ) -> None:
        """Convert ``roi`` with ``dimensions.pixel_size`` and delegate to the parent.

        Entries in ``slicing_dict`` override the ROI-derived slicing.
        """
        super().__init__(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=roi_to_slicing_dict(
                roi=roi,
                pixel_size=dimensions.pixel_size,
                slicing_dict=slicing_dict,
            ),
            remove_channel_selection=remove_channel_selection,
            roi=roi,
        )
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Protocol, TypeVar
|
|
3
|
+
|
|
4
|
+
import zarr
|
|
5
|
+
|
|
6
|
+
from ngio.common._roi import Roi, RoiPixels
|
|
7
|
+
from ngio.io_pipes._ops_axes import AxesOps
|
|
8
|
+
from ngio.io_pipes._ops_slices import SlicingOps
|
|
9
|
+
from ngio.io_pipes._ops_transforms import TransformProtocol
|
|
10
|
+
|
|
11
|
+
# Data type produced by a getter. It is covariant because the getter
# protocol only uses it as a return type (``get`` / ``__call__``).
GetterDataType = TypeVar("GetterDataType", covariant=True)
|
|
12
|
+
# Data type accepted by a setter. It is contravariant because the setter
# protocol only uses it as a parameter type (``set`` / ``__call__``).
SetterDataType = TypeVar("SetterDataType", contravariant=True)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DataGetterProtocol(Protocol[GetterDataType]):
    """Structural interface for objects that read a block of data.

    Calling an instance is shorthand for calling :meth:`get`.
    """

    @property
    def zarr_array(self) -> zarr.Array:
        """Zarr array associated with this getter."""

    @property
    def slicing_ops(self) -> SlicingOps:
        """Slicing operations used by this getter."""

    @property
    def axes_ops(self) -> AxesOps:
        """Axes operations used by this getter."""

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None:
        """Transforms attached to this getter, if any."""

    @property
    def roi(self) -> Roi | RoiPixels:
        """ROI associated with this getter."""

    def get(self) -> GetterDataType:
        """Read and return the data."""

    def __call__(self) -> GetterDataType:
        """Alias for :meth:`get`."""
        return self.get()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class DataSetterProtocol(Protocol[SetterDataType]):
    """Structural interface for objects that write a block of data.

    Calling an instance with a patch is shorthand for calling :meth:`set`.
    """

    @property
    def zarr_array(self) -> zarr.Array:
        """Zarr array associated with this setter."""

    @property
    def slicing_ops(self) -> SlicingOps:
        """Slicing operations used by this setter."""

    @property
    def axes_ops(self) -> AxesOps:
        """Axes operations used by this setter."""

    @property
    def transforms(self) -> Sequence[TransformProtocol] | None:
        """Transforms attached to this setter, if any."""

    @property
    def roi(self) -> Roi | RoiPixels:
        """ROI associated with this setter."""

    def set(self, patch: SetterDataType) -> None:
        """Write ``patch``."""

    def __call__(self, patch: SetterDataType) -> None:
        """Alias for :meth:`set`."""
        return self.set(patch)
|