ngio 0.3.5__py3-none-any.whl → 0.4.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ngio/__init__.py +6 -0
  2. ngio/common/__init__.py +50 -48
  3. ngio/common/_array_io_pipes.py +549 -0
  4. ngio/common/_array_io_utils.py +508 -0
  5. ngio/common/_dimensions.py +63 -27
  6. ngio/common/_masking_roi.py +38 -10
  7. ngio/common/_pyramid.py +9 -7
  8. ngio/common/_roi.py +571 -72
  9. ngio/common/_synt_images_utils.py +101 -0
  10. ngio/common/_zoom.py +17 -12
  11. ngio/common/transforms/__init__.py +5 -0
  12. ngio/common/transforms/_label.py +12 -0
  13. ngio/common/transforms/_zoom.py +109 -0
  14. ngio/experimental/__init__.py +5 -0
  15. ngio/experimental/iterators/__init__.py +17 -0
  16. ngio/experimental/iterators/_abstract_iterator.py +170 -0
  17. ngio/experimental/iterators/_feature.py +151 -0
  18. ngio/experimental/iterators/_image_processing.py +169 -0
  19. ngio/experimental/iterators/_rois_utils.py +127 -0
  20. ngio/experimental/iterators/_segmentation.py +278 -0
  21. ngio/hcs/_plate.py +41 -36
  22. ngio/images/__init__.py +22 -1
  23. ngio/images/_abstract_image.py +247 -117
  24. ngio/images/_create.py +15 -15
  25. ngio/images/_create_synt_container.py +128 -0
  26. ngio/images/_image.py +425 -62
  27. ngio/images/_label.py +33 -30
  28. ngio/images/_masked_image.py +396 -122
  29. ngio/images/_ome_zarr_container.py +203 -66
  30. ngio/{common → images}/_table_ops.py +41 -41
  31. ngio/ome_zarr_meta/ngio_specs/__init__.py +2 -8
  32. ngio/ome_zarr_meta/ngio_specs/_axes.py +151 -128
  33. ngio/ome_zarr_meta/ngio_specs/_channels.py +55 -18
  34. ngio/ome_zarr_meta/ngio_specs/_dataset.py +7 -7
  35. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +3 -3
  36. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +11 -68
  37. ngio/ome_zarr_meta/v04/_v04_spec_utils.py +1 -1
  38. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  39. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  40. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  41. ngio/resources/__init__.py +54 -0
  42. ngio/resources/resource_model.py +35 -0
  43. ngio/tables/backends/_abstract_backend.py +5 -6
  44. ngio/tables/backends/_anndata.py +1 -1
  45. ngio/tables/backends/_anndata_utils.py +3 -3
  46. ngio/tables/backends/_non_zarr_backends.py +1 -1
  47. ngio/tables/backends/_table_backends.py +0 -1
  48. ngio/tables/backends/_utils.py +3 -3
  49. ngio/tables/v1/_roi_table.py +156 -69
  50. ngio/utils/__init__.py +2 -3
  51. ngio/utils/_logger.py +19 -0
  52. ngio/utils/_zarr_utils.py +1 -5
  53. {ngio-0.3.5.dist-info → ngio-0.4.0a1.dist-info}/METADATA +3 -1
  54. ngio-0.4.0a1.dist-info/RECORD +76 -0
  55. ngio/common/_array_pipe.py +0 -288
  56. ngio/common/_axes_transforms.py +0 -64
  57. ngio/common/_common_types.py +0 -5
  58. ngio/common/_slicer.py +0 -96
  59. ngio-0.3.5.dist-info/RECORD +0 -61
  60. {ngio-0.3.5.dist-info → ngio-0.4.0a1.dist-info}/WHEEL +0 -0
  61. {ngio-0.3.5.dist-info → ngio-0.4.0a1.dist-info}/licenses/LICENSE +0 -0
ngio/__init__.py CHANGED
@@ -19,12 +19,15 @@ from ngio.hcs import (
19
19
  open_ome_zarr_well,
20
20
  )
21
21
  from ngio.images import (
22
+ ChannelSelectionModel,
22
23
  Image,
23
24
  Label,
24
25
  OmeZarrContainer,
25
26
  create_empty_ome_zarr,
26
27
  create_ome_zarr_from_array,
28
+ create_synthetic_ome_zarr,
27
29
  open_image,
30
+ open_label,
28
31
  open_ome_zarr_container,
29
32
  )
30
33
  from ngio.ome_zarr_meta.ngio_specs import (
@@ -38,6 +41,7 @@ from ngio.ome_zarr_meta.ngio_specs import (
38
41
  __all__ = [
39
42
  "ArrayLike",
40
43
  "AxesSetup",
44
+ "ChannelSelectionModel",
41
45
  "DefaultNgffVersion",
42
46
  "Dimensions",
43
47
  "Image",
@@ -54,7 +58,9 @@ __all__ = [
54
58
  "create_empty_plate",
55
59
  "create_empty_well",
56
60
  "create_ome_zarr_from_array",
61
+ "create_synthetic_ome_zarr",
57
62
  "open_image",
63
+ "open_label",
58
64
  "open_ome_zarr_container",
59
65
  "open_ome_zarr_plate",
60
66
  "open_ome_zarr_well",
ngio/common/__init__.py CHANGED
@@ -1,37 +1,38 @@
1
1
  """Common classes and functions that are used across the package."""
2
2
 
3
- from ngio.common._array_pipe import (
4
- get_masked_pipe,
5
- get_pipe,
6
- set_masked_pipe,
7
- set_pipe,
3
+ from ngio.common._array_io_pipes import (
4
+ build_dask_getter,
5
+ build_dask_setter,
6
+ build_masked_dask_getter,
7
+ build_masked_dask_setter,
8
+ build_masked_numpy_getter,
9
+ build_masked_numpy_setter,
10
+ build_numpy_getter,
11
+ build_numpy_setter,
8
12
  )
9
- from ngio.common._axes_transforms import (
10
- transform_dask_array,
11
- transform_list,
12
- transform_numpy_array,
13
+ from ngio.common._array_io_utils import (
14
+ ArrayLike,
15
+ SlicingInputType,
16
+ TransformProtocol,
17
+ apply_dask_axes_ops,
18
+ apply_numpy_axes_ops,
19
+ apply_sequence_axes_ops,
13
20
  )
14
- from ngio.common._common_types import ArrayLike
15
21
  from ngio.common._dimensions import Dimensions
16
22
  from ngio.common._masking_roi import compute_masking_roi
17
23
  from ngio.common._pyramid import consolidate_pyramid, init_empty_pyramid, on_disk_zoom
18
- from ngio.common._roi import Roi, RoiPixels, roi_to_slice_kwargs
19
- from ngio.common._slicer import (
20
- SliceTransform,
21
- compute_and_slices,
22
- dask_get_slice,
23
- dask_set_slice,
24
- numpy_get_slice,
25
- numpy_set_slice,
26
- )
27
- from ngio.common._table_ops import (
28
- concatenate_image_tables,
29
- concatenate_image_tables_as,
30
- concatenate_image_tables_as_async,
31
- concatenate_image_tables_async,
32
- conctatenate_tables,
33
- list_image_tables,
34
- list_image_tables_async,
24
+ from ngio.common._roi import (
25
+ Roi,
26
+ RoiPixels,
27
+ build_roi_dask_getter,
28
+ build_roi_dask_setter,
29
+ build_roi_masked_dask_getter,
30
+ build_roi_masked_dask_setter,
31
+ build_roi_masked_numpy_getter,
32
+ build_roi_masked_numpy_setter,
33
+ build_roi_numpy_getter,
34
+ build_roi_numpy_setter,
35
+ roi_to_slicing_dict,
35
36
  )
36
37
  from ngio.common._zoom import dask_zoom, numpy_zoom
37
38
 
@@ -40,31 +41,32 @@ __all__ = [
40
41
  "Dimensions",
41
42
  "Roi",
42
43
  "RoiPixels",
43
- "SliceTransform",
44
- "compute_and_slices",
44
+ "SlicingInputType",
45
+ "TransformProtocol",
46
+ "apply_dask_axes_ops",
47
+ "apply_numpy_axes_ops",
48
+ "apply_sequence_axes_ops",
49
+ "build_dask_getter",
50
+ "build_dask_setter",
51
+ "build_masked_dask_getter",
52
+ "build_masked_dask_setter",
53
+ "build_masked_numpy_getter",
54
+ "build_masked_numpy_setter",
55
+ "build_numpy_getter",
56
+ "build_numpy_setter",
57
+ "build_roi_dask_getter",
58
+ "build_roi_dask_setter",
59
+ "build_roi_masked_dask_getter",
60
+ "build_roi_masked_dask_setter",
61
+ "build_roi_masked_numpy_getter",
62
+ "build_roi_masked_numpy_setter",
63
+ "build_roi_numpy_getter",
64
+ "build_roi_numpy_setter",
45
65
  "compute_masking_roi",
46
- "concatenate_image_tables",
47
- "concatenate_image_tables_as",
48
- "concatenate_image_tables_as_async",
49
- "concatenate_image_tables_async",
50
- "conctatenate_tables",
51
66
  "consolidate_pyramid",
52
- "dask_get_slice",
53
- "dask_set_slice",
54
67
  "dask_zoom",
55
- "get_masked_pipe",
56
- "get_pipe",
57
68
  "init_empty_pyramid",
58
- "list_image_tables",
59
- "list_image_tables_async",
60
- "numpy_get_slice",
61
- "numpy_set_slice",
62
69
  "numpy_zoom",
63
70
  "on_disk_zoom",
64
- "roi_to_slice_kwargs",
65
- "set_masked_pipe",
66
- "set_pipe",
67
- "transform_dask_array",
68
- "transform_list",
69
- "transform_numpy_array",
71
+ "roi_to_slicing_dict",
70
72
  ]
@@ -0,0 +1,549 @@
1
+ from collections.abc import Callable, Sequence
2
+
3
+ import dask.array as da
4
+ import numpy as np
5
+ import zarr
6
+ from dask.array import Array as DaskArray
7
+
8
+ from ngio.common._array_io_utils import (
9
+ SlicingInputType,
10
+ TransformProtocol,
11
+ apply_dask_axes_ops,
12
+ apply_dask_transforms,
13
+ apply_inverse_dask_transforms,
14
+ apply_inverse_numpy_transforms,
15
+ apply_numpy_axes_ops,
16
+ apply_numpy_transforms,
17
+ get_slice_as_dask,
18
+ get_slice_as_numpy,
19
+ set_dask_patch,
20
+ set_numpy_patch,
21
+ setup_from_disk_pipe,
22
+ setup_to_disk_pipe,
23
+ )
24
+ from ngio.common._dimensions import Dimensions
25
+ from ngio.common._zoom import dask_zoom, numpy_zoom
26
+ from ngio.ome_zarr_meta.ngio_specs import SlicingOps
27
+
28
+ ##############################################################
29
+ #
30
+ # Concrete "From Disk" Pipes
31
+ #
32
+ ##############################################################
33
+
34
+
35
def _numpy_get_pipe(
    zarr_array: zarr.Array,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None = None,
) -> np.ndarray:
    """Read a region of ``zarr_array`` as a NumPy array.

    The read is done in three stages:

    1. slice the array with ``slicing_ops.slice_tuple``;
    2. apply the squeeze / transpose / expand axes operations stored in
       ``slicing_ops``;
    3. run the optional ``transforms`` through ``apply_numpy_transforms``.
    """
    raw = get_slice_as_numpy(zarr_array, slice_tuple=slicing_ops.slice_tuple)
    reshaped = apply_numpy_axes_ops(
        raw,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )
    return apply_numpy_transforms(
        reshaped, transforms=transforms, slicing_ops=slicing_ops
    )
52
+
53
+
54
def _dask_get_pipe(
    zarr_array: zarr.Array,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None,
) -> DaskArray:
    """Read a region of ``zarr_array`` as a Dask array.

    Stages, in order: ``get_slice_as_dask`` with ``slicing_ops.slice_tuple``,
    then ``apply_dask_axes_ops`` with the squeeze / transpose / expand axes
    stored in ``slicing_ops``, then ``apply_dask_transforms`` for the
    optional ``transforms``.
    """
    sliced = get_slice_as_dask(zarr_array, slice_tuple=slicing_ops.slice_tuple)
    reshaped = apply_dask_axes_ops(
        sliced,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )
    return apply_dask_transforms(
        reshaped, transforms=transforms, slicing_ops=slicing_ops
    )
71
+
72
+
73
def build_numpy_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[], np.ndarray]:
    """Build a zero-argument callable that reads a NumPy array from ``zarr_array``.

    The slicing/axes operations are resolved once, when this builder runs, via
    ``setup_from_disk_pipe``. Every call of the returned function then performs
    a fresh read through ``_numpy_get_pipe``.

    Args:
        zarr_array: The array to read from.
        dimensions: Dimension metadata forwarded to ``setup_from_disk_pipe``.
        axes_order: Optional axes order forwarded to ``setup_from_disk_pipe``.
        transforms: Optional transforms applied after the axes operations.
        slicing_dict: Optional per-axis slicing; ``None`` is treated as ``{}``.
        remove_channel_selection: Forwarded to ``setup_from_disk_pipe``.

    Returns:
        A callable returning the selected region as a ``np.ndarray``.
    """
    ops = setup_from_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict or {},
        remove_channel_selection=remove_channel_selection,
    )

    def _read() -> np.ndarray:
        return _numpy_get_pipe(
            zarr_array=zarr_array,
            slicing_ops=ops,
            transforms=transforms,
        )

    return _read
95
+
96
+
97
def build_dask_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[], DaskArray]:
    """Build a zero-argument callable that reads a Dask array from ``zarr_array``.

    ``setup_from_disk_pipe`` is called once, up front. The returned function
    delegates each read to ``_dask_get_pipe`` with the precomputed ops.

    Args:
        zarr_array: The array to read from.
        dimensions: Dimension metadata forwarded to ``setup_from_disk_pipe``.
        axes_order: Optional axes order forwarded to ``setup_from_disk_pipe``.
        transforms: Optional transforms applied after the axes operations.
        slicing_dict: Optional per-axis slicing; ``None`` is treated as ``{}``.
        remove_channel_selection: Forwarded to ``setup_from_disk_pipe``.

    Returns:
        A callable returning the selected region as a Dask array.
    """
    ops = setup_from_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict or {},
        remove_channel_selection=remove_channel_selection,
    )

    def _read() -> DaskArray:
        return _dask_get_pipe(
            zarr_array=zarr_array,
            slicing_ops=ops,
            transforms=transforms,
        )

    return _read
119
+
120
+
121
+ ##############################################################
122
+ #
123
+ # Concrete "To Disk" Pipes
124
+ #
125
+ ##############################################################
126
+
127
+
128
def _numpy_set_pipe(
    zarr_array: zarr.Array,
    patch: np.ndarray,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None,
) -> None:
    """Write a NumPy ``patch`` into ``zarr_array`` at ``slicing_ops.slice_tuple``.

    The patch is first passed through ``apply_inverse_numpy_transforms``, then
    through the axes operations stored in ``slicing_ops``. It is written only
    if its resulting dtype can be cast safely to the array dtype.

    Raises:
        ValueError: If the prepared patch dtype cannot be safely cast to
            ``zarr_array.dtype``.
    """
    untransformed = apply_inverse_numpy_transforms(
        patch, transforms=transforms, slicing_ops=slicing_ops
    )
    prepared = apply_numpy_axes_ops(
        untransformed,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )

    src_dtype = prepared.dtype
    dst_dtype = zarr_array.dtype
    # Refuse lossy writes instead of letting the store silently downcast.
    if np.can_cast(src_dtype, dst_dtype, casting="safe"):
        set_numpy_patch(zarr_array, prepared, slicing_ops.slice_tuple)
        return
    raise ValueError(
        f"Cannot safely cast patch of dtype {src_dtype} to "
        f"zarr array of dtype {dst_dtype}."
    )
150
+
151
+
152
def _dask_set_pipe(
    zarr_array: zarr.Array,
    patch: DaskArray,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None,
) -> None:
    """Write a Dask ``patch`` into ``zarr_array`` at ``slicing_ops.slice_tuple``.

    The patch goes through ``apply_inverse_dask_transforms`` and then the axes
    operations in ``slicing_ops``. It is handed to ``set_dask_patch`` only when
    its dtype can be safely cast to the array dtype.

    Raises:
        ValueError: If the prepared patch dtype cannot be safely cast to
            ``zarr_array.dtype``.
    """
    untransformed = apply_inverse_dask_transforms(
        patch, transforms=transforms, slicing_ops=slicing_ops
    )
    prepared = apply_dask_axes_ops(
        untransformed,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )

    src_dtype = prepared.dtype
    dst_dtype = zarr_array.dtype
    # Refuse lossy writes instead of letting the store silently downcast.
    if np.can_cast(src_dtype, dst_dtype, casting="safe"):
        set_dask_patch(zarr_array, prepared, slicing_ops.slice_tuple)
        return
    raise ValueError(
        f"Cannot safely cast patch of dtype {src_dtype} to "
        f"zarr array of dtype {dst_dtype}."
    )
173
+
174
+
175
def build_numpy_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[np.ndarray], None]:
    """Build a callable that writes a NumPy patch into ``zarr_array``.

    ``setup_to_disk_pipe`` is evaluated once, when this builder runs. The
    returned function forwards each patch to ``_numpy_set_pipe``, which may
    raise ``ValueError`` on an unsafe dtype cast.

    Args:
        zarr_array: The array to write to.
        dimensions: Dimension metadata forwarded to ``setup_to_disk_pipe``.
        axes_order: Optional axes order forwarded to ``setup_to_disk_pipe``.
        transforms: Optional transforms whose inverse is applied before writing.
        slicing_dict: Optional per-axis slicing; ``None`` is treated as ``{}``.
        remove_channel_selection: Forwarded to ``setup_to_disk_pipe``.

    Returns:
        A callable accepting the patch to write.
    """
    ops = setup_to_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict or {},
        remove_channel_selection=remove_channel_selection,
    )

    def _write(patch: np.ndarray) -> None:
        _numpy_set_pipe(
            zarr_array=zarr_array,
            patch=patch,
            slicing_ops=ops,
            transforms=transforms,
        )

    return _write
198
+
199
+
200
def build_dask_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[DaskArray], None]:
    """Build a callable that writes a Dask patch into ``zarr_array``.

    The to-disk slicing ops are computed once with ``setup_to_disk_pipe``.
    Each call of the returned function delegates to ``_dask_set_pipe``, which
    may raise ``ValueError`` on an unsafe dtype cast.

    Args:
        zarr_array: The array to write to.
        dimensions: Dimension metadata forwarded to ``setup_to_disk_pipe``.
        axes_order: Optional axes order forwarded to ``setup_to_disk_pipe``.
        transforms: Optional transforms whose inverse is applied before writing.
        slicing_dict: Optional per-axis slicing; ``None`` is treated as ``{}``.
        remove_channel_selection: Forwarded to ``setup_to_disk_pipe``.

    Returns:
        A callable accepting the Dask patch to write.
    """
    ops = setup_to_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict or {},
        remove_channel_selection=remove_channel_selection,
    )

    def _write(patch: DaskArray) -> None:
        _dask_set_pipe(
            zarr_array=zarr_array,
            patch=patch,
            slicing_ops=ops,
            transforms=transforms,
        )

    return _write
223
+
224
+
225
+ ################################################################
226
+ #
227
+ # Masked Array Pipes
228
+ #
229
+ ################################################################
230
+
231
+
232
def _match_data_shape(mask: np.ndarray, data_shape: tuple[int, ...]) -> np.ndarray:
    """Resample ``mask`` so that its shape matches ``data_shape``.

    Missing leading dimensions are added as singleton axes. The mask is then
    passed to ``numpy_zoom`` with ``order=0`` (presumably nearest-neighbour,
    see ``numpy_zoom``) and one zoom factor per axis.

    Raises:
        ValueError: If ``mask`` has more dimensions than ``data_shape``.
    """
    missing_dims = len(data_shape) - mask.ndim
    if missing_dims < 0:
        raise ValueError(
            "The mask has more dimensions than the data and cannot be matched."
        )
    if missing_dims > 0:
        mask = np.reshape(mask, (1,) * missing_dims + mask.shape)

    def _axis_factor(target: int, current: int) -> float:
        if current == target:
            return 1.0
        if current == 1:
            return target  # a singleton axis is stretched to the full length
        return target / current

    factors = tuple(
        _axis_factor(target, current)
        for target, current in zip(data_shape, mask.shape, strict=True)
    )
    matched: np.ndarray = numpy_zoom(mask, scale=factors, order=0)
    return matched
252
+
253
+
254
def _label_to_bool_mask_numpy(
    label_data: np.ndarray | DaskArray,
    label: int | None = None,
    data_shape: tuple[int, ...] | None = None,
    allow_scaling: bool = True,
) -> np.ndarray:
    """Turn label data into a boolean mask, optionally reshaped to ``data_shape``.

    Args:
        label_data: Label values.
        label: If given, select only pixels equal to this value. If ``None``,
            select every non-zero pixel.
        data_shape: Target shape. When it differs from ``label_data.shape``,
            the mask is either resampled with ``_match_data_shape``
            (``allow_scaling=True``) or broadcast with ``np.broadcast_to``.
        allow_scaling: Chooses between resampling and broadcasting.

    Returns:
        The boolean mask.
    """
    bool_mask = label_data != 0 if label is None else label_data == label

    if data_shape is None or label_data.shape == data_shape:
        return bool_mask
    if allow_scaling:
        return _match_data_shape(bool_mask, data_shape)
    return np.broadcast_to(bool_mask, data_shape)
272
+
273
+
274
def build_masked_numpy_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    fill_value: int | float = 0,
    allow_scaling: bool = True,
) -> Callable[[], np.ndarray]:
    """Build a callable that reads NumPy data masked by a label array.

    Each call reads the data and the label region. It builds a boolean mask
    (``label_data == label_id``, or ``label_data != 0`` when ``label_id`` is
    ``None``) and returns ``np.where(mask, data, fill_value)``.

    Args:
        zarr_array: Data array to read.
        dimensions: Dimension metadata of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimension metadata of ``label_zarr_array``.
        label_id: Label value to keep; ``None`` keeps every non-zero label.
        axes_order: Optional axes order used for both data and label reads.
        transforms: Optional transforms for the data read.
        label_transforms: Optional transforms for the label read.
        slicing_dict: Optional slicing for the data; ``None`` means ``{}``.
        label_slicing_dict: Optional slicing for the label. Only ``None``
            falls back to ``slicing_dict``; an explicit ``{}`` is honoured.
        fill_value: Value used outside the mask.
        allow_scaling: If true, a label mask whose shape differs from the data
            is resampled; otherwise it is broadcast.

    Returns:
        A zero-argument callable returning the masked ``np.ndarray``.
    """
    slicing_dict = slicing_dict or {}
    # Use an identity check: ``or`` would also replace an explicitly passed
    # empty dict with the data slicing.
    if label_slicing_dict is None:
        label_slicing_dict = slicing_dict

    data_getter = build_numpy_getter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=False,
    )

    label_data_getter = build_numpy_getter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_transforms,
        slicing_dict=label_slicing_dict,
        # Label reads drop any channel selection present in the slicing.
        remove_channel_selection=True,
    )

    def get_masked_data_as_numpy() -> np.ndarray:
        data = data_getter()
        label_data = label_data_getter()
        bool_mask = _label_to_bool_mask_numpy(
            label_data=label_data,
            label=label_id,
            data_shape=data.shape,
            allow_scaling=allow_scaling,
        )
        return np.where(bool_mask, data, fill_value)

    return get_masked_data_as_numpy
324
+
325
+
326
def _match_data_shape_dask(mask: da.Array, data_shape: tuple[int, ...]) -> da.Array:
    """Resample a Dask ``mask`` so that its shape matches ``data_shape``.

    Missing leading dimensions become singleton axes (via ``da.reshape``). The
    mask is then passed to ``dask_zoom`` with ``order=0`` and one zoom factor
    per axis.

    Raises:
        ValueError: If ``mask`` has more dimensions than ``data_shape``.
    """
    missing_dims = len(data_shape) - mask.ndim
    if missing_dims < 0:
        raise ValueError(
            "The mask has more dimensions than the data and cannot be matched."
        )
    if missing_dims > 0:
        mask = da.reshape(mask, (1,) * missing_dims + mask.shape)

    def _axis_factor(target: int, current: int) -> float:
        if current == target:
            return 1.0
        if current == 1:
            return target  # a singleton axis is stretched to the full length
        return target / current

    factors = tuple(
        _axis_factor(target, current)
        for target, current in zip(data_shape, mask.shape, strict=True)
    )
    matched: da.Array = dask_zoom(mask, scale=factors, order=0)
    return matched
346
+
347
+
348
def _label_to_bool_mask_dask(
    label_data: DaskArray,
    label: int | None = None,
    data_shape: tuple[int, ...] | None = None,
    allow_scaling: bool = True,
) -> DaskArray:
    """Turn Dask label data into a boolean mask, optionally reshaped.

    Args:
        label_data: Label values.
        label: If given, select only pixels equal to this value. If ``None``,
            select every non-zero pixel.
        data_shape: Target shape. When it differs from ``label_data.shape``,
            the mask is either resampled with ``_match_data_shape_dask``
            (``allow_scaling=True``) or broadcast with ``da.broadcast_to``.
        allow_scaling: Chooses between resampling and broadcasting.

    Returns:
        The boolean Dask mask.
    """
    bool_mask = label_data != 0 if label is None else label_data == label

    if data_shape is None or label_data.shape == data_shape:
        return bool_mask
    if allow_scaling:
        return _match_data_shape_dask(bool_mask, data_shape)
    return da.broadcast_to(bool_mask, data_shape)
366
+
367
+
368
def build_masked_dask_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    fill_value: int | float = 0,
    allow_scaling: bool = True,
) -> Callable[[], DaskArray]:
    """Build a callable that reads Dask data masked by a label array.

    Each call reads the data and the label region. It builds a boolean mask
    (``label_data == label_id``, or ``label_data != 0`` when ``label_id`` is
    ``None``) and returns ``da.where(mask, data, fill_value)``.

    Args:
        zarr_array: Data array to read.
        dimensions: Dimension metadata of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimension metadata of ``label_zarr_array``.
        label_id: Label value to keep; ``None`` keeps every non-zero label.
        axes_order: Optional axes order used for both data and label reads.
        transforms: Optional transforms for the data read.
        label_transforms: Optional transforms for the label read.
        slicing_dict: Optional slicing for the data; ``None`` means ``{}``.
        label_slicing_dict: Optional slicing for the label. Only ``None``
            falls back to ``slicing_dict``; an explicit ``{}`` is honoured.
        fill_value: Value used outside the mask.
        allow_scaling: If true, a label mask whose shape differs from the data
            is resampled; otherwise it is broadcast.

    Returns:
        A zero-argument callable returning the masked Dask array.
    """
    slicing_dict = slicing_dict or {}
    # Use an identity check: ``or`` would also replace an explicitly passed
    # empty dict with the data slicing.
    if label_slicing_dict is None:
        label_slicing_dict = slicing_dict

    data_getter = build_dask_getter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=False,
    )

    label_data_getter = build_dask_getter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_transforms,
        slicing_dict=label_slicing_dict,
        # Label reads drop any channel selection present in the slicing.
        remove_channel_selection=True,
    )

    def get_masked_data_as_dask() -> DaskArray:
        data = data_getter()
        label_data = label_data_getter()
        # Normalise the shape to plain ints before comparing/resampling.
        data_shape = tuple(int(dim) for dim in data.shape)
        bool_mask = _label_to_bool_mask_dask(
            label_data=label_data,
            label=label_id,
            data_shape=data_shape,
            allow_scaling=allow_scaling,
        )
        return da.where(bool_mask, data, fill_value)

    return get_masked_data_as_dask
419
+
420
+
421
def build_masked_numpy_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    data_getter: Callable[[], np.ndarray] | None = None,
    label_data_getter: Callable[[], np.ndarray] | None = None,
    allow_scaling: bool = True,
) -> Callable[[np.ndarray], None]:
    """Build a callable that writes a NumPy patch only where a label mask is set.

    Each call reads the current data and label region. It builds a boolean
    mask (``label_data == label_id``, or ``label_data != 0`` when ``label_id``
    is ``None``), combines them with ``np.where(mask, patch, data)`` and writes
    the result through a setter built by ``build_numpy_setter``. That setter
    may raise ``ValueError`` if the combined dtype cannot be safely cast to the
    array dtype.

    Args:
        zarr_array: Data array to write to.
        dimensions: Dimension metadata of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimension metadata of ``label_zarr_array``.
        label_id: Label value to write into; ``None`` means every non-zero label.
        axes_order: Optional axes order used for all reads and the write.
        transforms: Optional transforms for the data read/write.
        label_transforms: Optional transforms for the label read.
        slicing_dict: Optional slicing for the data; ``None`` means ``{}``.
        label_slicing_dict: Optional slicing for the label. Only ``None``
            falls back to ``slicing_dict``; an explicit ``{}`` is honoured.
        data_getter: Optional pre-built getter for the current data. If
            omitted, one is built with ``build_numpy_getter``.
        label_data_getter: Optional pre-built getter for the label data. If
            omitted, one is built with ``build_numpy_getter`` and
            ``remove_channel_selection=True``.
        allow_scaling: If true, a label mask whose shape differs from the data
            is resampled; otherwise it is broadcast.

    Returns:
        A callable accepting the patch to write.
    """
    slicing_dict = slicing_dict or {}
    # Use an identity check: ``or`` would also replace an explicitly passed
    # empty dict with the data slicing.
    if label_slicing_dict is None:
        label_slicing_dict = slicing_dict

    if data_getter is None:
        data_getter = build_numpy_getter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=slicing_dict,
            remove_channel_selection=False,
        )

    if label_data_getter is None:
        label_data_getter = build_numpy_getter(
            zarr_array=label_zarr_array,
            dimensions=label_dimensions,
            axes_order=axes_order,
            transforms=label_transforms,
            slicing_dict=label_slicing_dict,
            remove_channel_selection=True,
        )

    masked_data_setter = build_numpy_setter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=False,
    )

    def set_patch_masked_as_numpy(patch: np.ndarray) -> None:
        """Set a numpy patch to the array, masked by the label array."""
        data = data_getter()
        label_data = label_data_getter()
        bool_mask = _label_to_bool_mask_numpy(
            label_data=label_data,
            label=label_id,
            data_shape=data.shape,
            allow_scaling=allow_scaling,
        )
        # Keep the existing values outside the mask; take the patch inside it.
        masked_data_setter(np.where(bool_mask, patch, data))

    return set_patch_masked_as_numpy
484
+
485
+
486
def build_masked_dask_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    data_getter: Callable[[], DaskArray] | None = None,
    label_data_getter: Callable[[], DaskArray] | None = None,
    allow_scaling: bool = True,
) -> Callable[[DaskArray], None]:
    """Build a callable that writes a Dask patch only where a label mask is set.

    Each call reads the current data and label region. It builds a boolean
    mask (``label_data == label_id``, or ``label_data != 0`` when ``label_id``
    is ``None``), combines them with ``da.where(mask, patch, data)`` and writes
    the result through a setter built by ``build_dask_setter``. That setter
    may raise ``ValueError`` if the combined dtype cannot be safely cast to the
    array dtype.

    Args:
        zarr_array: Data array to write to.
        dimensions: Dimension metadata of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimension metadata of ``label_zarr_array``.
        label_id: Label value to write into; ``None`` means every non-zero label.
        axes_order: Optional axes order used for all reads and the write.
        transforms: Optional transforms for the data read/write.
        label_transforms: Optional transforms for the label read.
        slicing_dict: Optional slicing for the data; ``None`` means ``{}``.
        label_slicing_dict: Optional slicing for the label. Only ``None``
            falls back to ``slicing_dict``; an explicit ``{}`` is honoured.
        data_getter: Optional pre-built getter for the current data. If
            omitted, one is built with ``build_dask_getter``.
        label_data_getter: Optional pre-built getter for the label data. If
            omitted, one is built with ``build_dask_getter`` and
            ``remove_channel_selection=True``.
        allow_scaling: If true, a label mask whose shape differs from the data
            is resampled; otherwise it is broadcast.

    Returns:
        A callable accepting the Dask patch to write.
    """
    slicing_dict = slicing_dict or {}
    # Use an identity check: ``or`` would also replace an explicitly passed
    # empty dict with the data slicing.
    if label_slicing_dict is None:
        label_slicing_dict = slicing_dict

    if data_getter is None:
        data_getter = build_dask_getter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=slicing_dict,
            remove_channel_selection=False,
        )

    if label_data_getter is None:
        label_data_getter = build_dask_getter(
            zarr_array=label_zarr_array,
            dimensions=label_dimensions,
            axes_order=axes_order,
            transforms=label_transforms,
            slicing_dict=label_slicing_dict,
            remove_channel_selection=True,
        )

    data_setter = build_dask_setter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=False,
    )

    def set_patch_masked_as_dask(patch: DaskArray) -> None:
        """Set a dask patch to the array, masked by the label array."""
        data = data_getter()
        label_data = label_data_getter()
        # Normalise the shape to plain ints before comparing/resampling.
        data_shape = tuple(int(dim) for dim in data.shape)
        bool_mask = _label_to_bool_mask_dask(
            label_data=label_data,
            label=label_id,
            data_shape=data_shape,
            allow_scaling=allow_scaling,
        )
        # Keep the existing values outside the mask; take the patch inside it.
        data_setter(da.where(bool_mask, patch, data))

    return set_patch_masked_as_dask