ngio 0.3.5__py3-none-any.whl → 0.4.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. ngio/__init__.py +6 -0
  2. ngio/common/__init__.py +50 -48
  3. ngio/common/_array_io_pipes.py +554 -0
  4. ngio/common/_array_io_utils.py +508 -0
  5. ngio/common/_dimensions.py +63 -27
  6. ngio/common/_masking_roi.py +38 -10
  7. ngio/common/_pyramid.py +9 -7
  8. ngio/common/_roi.py +583 -72
  9. ngio/common/_synt_images_utils.py +101 -0
  10. ngio/common/_zoom.py +17 -12
  11. ngio/common/transforms/__init__.py +5 -0
  12. ngio/common/transforms/_label.py +12 -0
  13. ngio/common/transforms/_zoom.py +109 -0
  14. ngio/experimental/__init__.py +5 -0
  15. ngio/experimental/iterators/__init__.py +17 -0
  16. ngio/experimental/iterators/_abstract_iterator.py +170 -0
  17. ngio/experimental/iterators/_feature.py +151 -0
  18. ngio/experimental/iterators/_image_processing.py +169 -0
  19. ngio/experimental/iterators/_rois_utils.py +127 -0
  20. ngio/experimental/iterators/_segmentation.py +282 -0
  21. ngio/hcs/_plate.py +41 -36
  22. ngio/images/__init__.py +22 -1
  23. ngio/images/_abstract_image.py +247 -117
  24. ngio/images/_create.py +15 -15
  25. ngio/images/_create_synt_container.py +128 -0
  26. ngio/images/_image.py +425 -62
  27. ngio/images/_label.py +33 -30
  28. ngio/images/_masked_image.py +396 -122
  29. ngio/images/_ome_zarr_container.py +203 -66
  30. ngio/{common → images}/_table_ops.py +41 -41
  31. ngio/ome_zarr_meta/ngio_specs/__init__.py +2 -8
  32. ngio/ome_zarr_meta/ngio_specs/_axes.py +151 -128
  33. ngio/ome_zarr_meta/ngio_specs/_channels.py +55 -18
  34. ngio/ome_zarr_meta/ngio_specs/_dataset.py +7 -7
  35. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +3 -3
  36. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +11 -68
  37. ngio/ome_zarr_meta/v04/_v04_spec_utils.py +1 -1
  38. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  39. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  40. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  41. ngio/resources/__init__.py +54 -0
  42. ngio/resources/resource_model.py +35 -0
  43. ngio/tables/backends/_abstract_backend.py +5 -6
  44. ngio/tables/backends/_anndata.py +1 -1
  45. ngio/tables/backends/_anndata_utils.py +3 -3
  46. ngio/tables/backends/_non_zarr_backends.py +1 -1
  47. ngio/tables/backends/_table_backends.py +0 -1
  48. ngio/tables/backends/_utils.py +3 -3
  49. ngio/tables/v1/_roi_table.py +156 -69
  50. ngio/utils/__init__.py +2 -3
  51. ngio/utils/_logger.py +19 -0
  52. ngio/utils/_zarr_utils.py +1 -5
  53. {ngio-0.3.5.dist-info → ngio-0.4.0a2.dist-info}/METADATA +3 -1
  54. ngio-0.4.0a2.dist-info/RECORD +76 -0
  55. ngio/common/_array_pipe.py +0 -288
  56. ngio/common/_axes_transforms.py +0 -64
  57. ngio/common/_common_types.py +0 -5
  58. ngio/common/_slicer.py +0 -96
  59. ngio-0.3.5.dist-info/RECORD +0 -61
  60. {ngio-0.3.5.dist-info → ngio-0.4.0a2.dist-info}/WHEEL +0 -0
  61. {ngio-0.3.5.dist-info → ngio-0.4.0a2.dist-info}/licenses/LICENSE +0 -0
ngio/__init__.py CHANGED
@@ -19,12 +19,15 @@ from ngio.hcs import (
19
19
  open_ome_zarr_well,
20
20
  )
21
21
  from ngio.images import (
22
+ ChannelSelectionModel,
22
23
  Image,
23
24
  Label,
24
25
  OmeZarrContainer,
25
26
  create_empty_ome_zarr,
26
27
  create_ome_zarr_from_array,
28
+ create_synthetic_ome_zarr,
27
29
  open_image,
30
+ open_label,
28
31
  open_ome_zarr_container,
29
32
  )
30
33
  from ngio.ome_zarr_meta.ngio_specs import (
@@ -38,6 +41,7 @@ from ngio.ome_zarr_meta.ngio_specs import (
38
41
  __all__ = [
39
42
  "ArrayLike",
40
43
  "AxesSetup",
44
+ "ChannelSelectionModel",
41
45
  "DefaultNgffVersion",
42
46
  "Dimensions",
43
47
  "Image",
@@ -54,7 +58,9 @@ __all__ = [
54
58
  "create_empty_plate",
55
59
  "create_empty_well",
56
60
  "create_ome_zarr_from_array",
61
+ "create_synthetic_ome_zarr",
57
62
  "open_image",
63
+ "open_label",
58
64
  "open_ome_zarr_container",
59
65
  "open_ome_zarr_plate",
60
66
  "open_ome_zarr_well",
ngio/common/__init__.py CHANGED
@@ -1,37 +1,38 @@
1
1
  """Common classes and functions that are used across the package."""
2
2
 
3
- from ngio.common._array_pipe import (
4
- get_masked_pipe,
5
- get_pipe,
6
- set_masked_pipe,
7
- set_pipe,
3
+ from ngio.common._array_io_pipes import (
4
+ build_dask_getter,
5
+ build_dask_setter,
6
+ build_masked_dask_getter,
7
+ build_masked_dask_setter,
8
+ build_masked_numpy_getter,
9
+ build_masked_numpy_setter,
10
+ build_numpy_getter,
11
+ build_numpy_setter,
8
12
  )
9
- from ngio.common._axes_transforms import (
10
- transform_dask_array,
11
- transform_list,
12
- transform_numpy_array,
13
+ from ngio.common._array_io_utils import (
14
+ ArrayLike,
15
+ SlicingInputType,
16
+ TransformProtocol,
17
+ apply_dask_axes_ops,
18
+ apply_numpy_axes_ops,
19
+ apply_sequence_axes_ops,
13
20
  )
14
- from ngio.common._common_types import ArrayLike
15
21
  from ngio.common._dimensions import Dimensions
16
22
  from ngio.common._masking_roi import compute_masking_roi
17
23
  from ngio.common._pyramid import consolidate_pyramid, init_empty_pyramid, on_disk_zoom
18
- from ngio.common._roi import Roi, RoiPixels, roi_to_slice_kwargs
19
- from ngio.common._slicer import (
20
- SliceTransform,
21
- compute_and_slices,
22
- dask_get_slice,
23
- dask_set_slice,
24
- numpy_get_slice,
25
- numpy_set_slice,
26
- )
27
- from ngio.common._table_ops import (
28
- concatenate_image_tables,
29
- concatenate_image_tables_as,
30
- concatenate_image_tables_as_async,
31
- concatenate_image_tables_async,
32
- conctatenate_tables,
33
- list_image_tables,
34
- list_image_tables_async,
24
+ from ngio.common._roi import (
25
+ Roi,
26
+ RoiPixels,
27
+ build_roi_dask_getter,
28
+ build_roi_dask_setter,
29
+ build_roi_masked_dask_getter,
30
+ build_roi_masked_dask_setter,
31
+ build_roi_masked_numpy_getter,
32
+ build_roi_masked_numpy_setter,
33
+ build_roi_numpy_getter,
34
+ build_roi_numpy_setter,
35
+ roi_to_slicing_dict,
35
36
  )
36
37
  from ngio.common._zoom import dask_zoom, numpy_zoom
37
38
 
@@ -40,31 +41,32 @@ __all__ = [
40
41
  "Dimensions",
41
42
  "Roi",
42
43
  "RoiPixels",
43
- "SliceTransform",
44
- "compute_and_slices",
44
+ "SlicingInputType",
45
+ "TransformProtocol",
46
+ "apply_dask_axes_ops",
47
+ "apply_numpy_axes_ops",
48
+ "apply_sequence_axes_ops",
49
+ "build_dask_getter",
50
+ "build_dask_setter",
51
+ "build_masked_dask_getter",
52
+ "build_masked_dask_setter",
53
+ "build_masked_numpy_getter",
54
+ "build_masked_numpy_setter",
55
+ "build_numpy_getter",
56
+ "build_numpy_setter",
57
+ "build_roi_dask_getter",
58
+ "build_roi_dask_setter",
59
+ "build_roi_masked_dask_getter",
60
+ "build_roi_masked_dask_setter",
61
+ "build_roi_masked_numpy_getter",
62
+ "build_roi_masked_numpy_setter",
63
+ "build_roi_numpy_getter",
64
+ "build_roi_numpy_setter",
45
65
  "compute_masking_roi",
46
- "concatenate_image_tables",
47
- "concatenate_image_tables_as",
48
- "concatenate_image_tables_as_async",
49
- "concatenate_image_tables_async",
50
- "conctatenate_tables",
51
66
  "consolidate_pyramid",
52
- "dask_get_slice",
53
- "dask_set_slice",
54
67
  "dask_zoom",
55
- "get_masked_pipe",
56
- "get_pipe",
57
68
  "init_empty_pyramid",
58
- "list_image_tables",
59
- "list_image_tables_async",
60
- "numpy_get_slice",
61
- "numpy_set_slice",
62
69
  "numpy_zoom",
63
70
  "on_disk_zoom",
64
- "roi_to_slice_kwargs",
65
- "set_masked_pipe",
66
- "set_pipe",
67
- "transform_dask_array",
68
- "transform_list",
69
- "transform_numpy_array",
71
+ "roi_to_slicing_dict",
70
72
  ]
ngio/common/_array_io_pipes.py ADDED
@@ -0,0 +1,554 @@
1
+ from collections.abc import Callable, Sequence
2
+
3
+ import dask.array as da
4
+ import numpy as np
5
+ import zarr
6
+ from dask.array import Array as DaskArray
7
+
8
+ from ngio.common._array_io_utils import (
9
+ SlicingInputType,
10
+ TransformProtocol,
11
+ apply_dask_axes_ops,
12
+ apply_dask_transforms,
13
+ apply_inverse_dask_transforms,
14
+ apply_inverse_numpy_transforms,
15
+ apply_numpy_axes_ops,
16
+ apply_numpy_transforms,
17
+ get_slice_as_dask,
18
+ get_slice_as_numpy,
19
+ set_dask_patch,
20
+ set_numpy_patch,
21
+ setup_from_disk_pipe,
22
+ setup_to_disk_pipe,
23
+ )
24
+ from ngio.common._dimensions import Dimensions
25
+ from ngio.common._zoom import dask_zoom, numpy_zoom
26
+ from ngio.ome_zarr_meta.ngio_specs import SlicingOps
27
+
28
+ ##############################################################
29
+ #
30
+ # Concrete "From Disk" Pipes
31
+ #
32
+ ##############################################################
33
+
34
+
35
def _numpy_get_pipe(
    zarr_array: zarr.Array,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None = None,
) -> np.ndarray:
    """Read a region of ``zarr_array`` and return it as a numpy array.

    Processing order:
      1. read the region selected by ``slicing_ops.slice_tuple``;
      2. apply the squeeze / transpose / expand axes operations stored in
         ``slicing_ops``;
      3. apply the optional ``transforms`` via ``apply_numpy_transforms``.

    Args:
        zarr_array: The on-disk array to read from.
        slicing_ops: Precomputed slicing and axes operations.
        transforms: Optional transforms applied after the axes operations.

    Returns:
        The processed numpy array.
    """
    region = get_slice_as_numpy(zarr_array, slice_tuple=slicing_ops.slice_tuple)
    reordered = apply_numpy_axes_ops(
        region,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )
    return apply_numpy_transforms(
        reordered, transforms=transforms, slicing_ops=slicing_ops
    )
52
+
53
+
54
def _dask_get_pipe(
    zarr_array: zarr.Array,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None = None,
) -> DaskArray:
    """Read a region of ``zarr_array`` and return it as a dask array.

    Mirrors ``_numpy_get_pipe``:
      1. read the region selected by ``slicing_ops.slice_tuple``;
      2. apply the squeeze / transpose / expand axes operations stored in
         ``slicing_ops``;
      3. apply the optional ``transforms`` via ``apply_dask_transforms``.

    Args:
        zarr_array: The on-disk array to read from.
        slicing_ops: Precomputed slicing and axes operations.
        transforms: Optional transforms applied after the axes operations.
            Defaults to ``None``, matching ``_numpy_get_pipe``.

    Returns:
        The processed dask array.
    """
    _array = get_slice_as_dask(zarr_array, slice_tuple=slicing_ops.slice_tuple)
    _array = apply_dask_axes_ops(
        _array,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )

    _array = apply_dask_transforms(
        _array, transforms=transforms, slicing_ops=slicing_ops
    )
    return _array
71
+
72
+
73
def build_numpy_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[], np.ndarray]:
    """Build a zero-argument callable that reads ``zarr_array`` as numpy.

    ``setup_from_disk_pipe`` resolves the slicing and axes operations once,
    when this builder runs. Each call of the returned function then performs
    the actual read through ``_numpy_get_pipe``.

    Args:
        zarr_array: The on-disk array to read from.
        dimensions: Dimensions describing ``zarr_array``.
        axes_order: Optional axes order for the returned array.
        transforms: Optional transforms applied after reading.
        slicing_dict: Optional per-axis selection; ``None`` or empty means
            no selection.
        remove_channel_selection: Forwarded to ``setup_from_disk_pipe``.

    Returns:
        A function that takes no arguments and returns a numpy array.
    """
    if not slicing_dict:
        slicing_dict = {}
    ops = setup_from_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )

    def _read() -> np.ndarray:
        return _numpy_get_pipe(
            zarr_array=zarr_array, slicing_ops=ops, transforms=transforms
        )

    return _read
96
+
97
+
98
def build_dask_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[], DaskArray]:
    """Build a zero-argument callable that reads ``zarr_array`` as dask.

    ``setup_from_disk_pipe`` resolves the slicing and axes operations once,
    when this builder runs. Each call of the returned function then performs
    the read through ``_dask_get_pipe``.

    Args:
        zarr_array: The on-disk array to read from.
        dimensions: Dimensions describing ``zarr_array``.
        axes_order: Optional axes order for the returned array.
        transforms: Optional transforms applied after reading.
        slicing_dict: Optional per-axis selection; ``None`` or empty means
            no selection.
        remove_channel_selection: Forwarded to ``setup_from_disk_pipe``.

    Returns:
        A function that takes no arguments and returns a dask array.
    """
    if not slicing_dict:
        slicing_dict = {}
    ops = setup_from_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )

    def _read() -> DaskArray:
        return _dask_get_pipe(
            zarr_array=zarr_array, slicing_ops=ops, transforms=transforms
        )

    return _read
120
+
121
+
122
+ ##############################################################
123
+ #
124
+ # Concrete "To Disk" Pipes
125
+ #
126
+ ##############################################################
127
+
128
+
129
def _numpy_set_pipe(
    zarr_array: zarr.Array,
    patch: np.ndarray,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None = None,
) -> None:
    """Write a numpy ``patch`` into ``zarr_array``.

    The patch goes through three steps:
      1. the inverse of ``transforms``;
      2. the axes operations stored in ``slicing_ops`` (presumably built by
         ``setup_to_disk_pipe`` to map back to the on-disk layout — TODO
         confirm);
      3. a write at ``slicing_ops.slice_tuple``.

    Args:
        zarr_array: The on-disk array to write into.
        patch: The data to write, in the caller-facing layout.
        slicing_ops: Precomputed slicing and axes operations.
        transforms: Optional transforms whose inverse is applied to the
            patch. Defaults to ``None``, matching ``_numpy_get_pipe``.

    Raises:
        ValueError: If the patch dtype cannot be cast to the array dtype
            under numpy's ``"safe"`` casting rule.
    """
    _patch = apply_inverse_numpy_transforms(
        patch, transforms=transforms, slicing_ops=slicing_ops
    )
    _patch = apply_numpy_axes_ops(
        _patch,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )

    # Reject lossy writes (e.g. float64 -> uint16) explicitly rather than
    # relying on whatever conversion the underlying write would perform.
    if not np.can_cast(_patch.dtype, zarr_array.dtype, casting="safe"):
        raise ValueError(
            f"Cannot safely cast patch of dtype {_patch.dtype} to "
            f"zarr array of dtype {zarr_array.dtype}."
        )
    set_numpy_patch(zarr_array, _patch, slicing_ops.slice_tuple)
151
+
152
+
153
def _dask_set_pipe(
    zarr_array: zarr.Array,
    patch: DaskArray,
    slicing_ops: SlicingOps,
    transforms: Sequence[TransformProtocol] | None = None,
) -> None:
    """Write a dask ``patch`` into ``zarr_array``.

    Mirrors ``_numpy_set_pipe``:
      1. apply the inverse of ``transforms``;
      2. apply the axes operations stored in ``slicing_ops``;
      3. write at ``slicing_ops.slice_tuple`` via ``set_dask_patch``.

    Args:
        zarr_array: The on-disk array to write into.
        patch: The data to write, in the caller-facing layout.
        slicing_ops: Precomputed slicing and axes operations.
        transforms: Optional transforms whose inverse is applied to the
            patch. Defaults to ``None``, matching ``_numpy_get_pipe``.

    Raises:
        ValueError: If the patch dtype cannot be cast to the array dtype
            under numpy's ``"safe"`` casting rule.
    """
    _patch = apply_inverse_dask_transforms(
        patch, transforms=transforms, slicing_ops=slicing_ops
    )
    _patch = apply_dask_axes_ops(
        _patch,
        squeeze_axes=slicing_ops.squeeze_axes,
        transpose_axes=slicing_ops.transpose_axes,
        expand_axes=slicing_ops.expand_axes,
    )
    # The dtype check only inspects metadata, so it does not compute the
    # dask graph.
    if not np.can_cast(_patch.dtype, zarr_array.dtype, casting="safe"):
        raise ValueError(
            f"Cannot safely cast patch of dtype {_patch.dtype} to "
            f"zarr array of dtype {zarr_array.dtype}."
        )
    set_dask_patch(zarr_array, _patch, slicing_ops.slice_tuple)
174
+
175
+
176
def build_numpy_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[np.ndarray], None]:
    """Build a callable that writes a numpy patch into ``zarr_array``.

    ``setup_to_disk_pipe`` resolves the slicing and axes operations once,
    when this builder runs. Each call of the returned function writes its
    argument through ``_numpy_set_pipe``.

    Args:
        zarr_array: The on-disk array to write into.
        dimensions: Dimensions describing ``zarr_array``.
        axes_order: Optional axes order of the patches that will be passed.
        transforms: Optional transforms whose inverse is applied before
            writing.
        slicing_dict: Optional per-axis selection; ``None`` or empty means
            no selection.
        remove_channel_selection: Forwarded to ``setup_to_disk_pipe``.

    Returns:
        A function that takes a numpy patch and returns ``None``.
    """
    if not slicing_dict:
        slicing_dict = {}
    ops = setup_to_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )

    def _write(patch: np.ndarray) -> None:
        return _numpy_set_pipe(
            zarr_array=zarr_array,
            patch=patch,
            slicing_ops=ops,
            transforms=transforms,
        )

    return _write
199
+
200
+
201
def build_dask_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    remove_channel_selection: bool = False,
) -> Callable[[DaskArray], None]:
    """Build a callable that writes a dask patch into ``zarr_array``.

    ``setup_to_disk_pipe`` resolves the slicing and axes operations once,
    when this builder runs. Each call of the returned function writes its
    argument through ``_dask_set_pipe``.

    Args:
        zarr_array: The on-disk array to write into.
        dimensions: Dimensions describing ``zarr_array``.
        axes_order: Optional axes order of the patches that will be passed.
        transforms: Optional transforms whose inverse is applied before
            writing.
        slicing_dict: Optional per-axis selection; ``None`` or empty means
            no selection.
        remove_channel_selection: Forwarded to ``setup_to_disk_pipe``.

    Returns:
        A function that takes a dask patch and returns ``None``.
    """
    if not slicing_dict:
        slicing_dict = {}
    ops = setup_to_disk_pipe(
        dimensions=dimensions,
        axes_order=axes_order,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )

    def _write(patch: DaskArray) -> None:
        return _dask_set_pipe(
            zarr_array=zarr_array,
            patch=patch,
            slicing_ops=ops,
            transforms=transforms,
        )

    return _write
224
+
225
+
226
+ ################################################################
227
+ #
228
+ # Masked Array Pipes
229
+ #
230
+ ################################################################
231
+
232
+
233
def _match_data_shape(mask: np.ndarray, data_shape: tuple[int, ...]) -> np.ndarray:
    """Resize ``mask`` so that its shape equals ``data_shape``.

    If ``mask`` has fewer dimensions than the data, leading singleton axes
    are prepended first. Every axis is then scaled by
    ``data_size / mask_size`` using ``numpy_zoom`` with ``order=0``.
    Equal sizes get a factor of ``1.0``; a singleton axis gets the integer
    data size.

    Raises:
        ValueError: If ``mask`` has more dimensions than ``data_shape``.
    """
    n_missing = len(data_shape) - mask.ndim
    if n_missing < 0:
        raise ValueError(
            "The mask has more dimensions than the data and cannot be matched."
        )
    if n_missing > 0:
        mask = np.reshape(mask, (1,) * n_missing + mask.shape)

    scale = tuple(
        1.0 if m_size == d_size else (d_size if m_size == 1 else d_size / m_size)
        for d_size, m_size in zip(data_shape, mask.shape, strict=True)
    )
    return numpy_zoom(mask, scale=scale, order=0)
253
+
254
+
255
def _label_to_bool_mask_numpy(
    label_data: np.ndarray | DaskArray,
    label: int | None = None,
    data_shape: tuple[int, ...] | None = None,
    allow_scaling: bool = True,
) -> np.ndarray:
    """Turn label data into a boolean mask, optionally shaped like the data.

    Args:
        label_data: Label image.
        label: Label value to select. If ``None``, every non-zero value is
            selected.
        data_shape: Target shape. If given and different from
            ``label_data.shape``, the mask is adapted to it.
        allow_scaling: If ``True``, adapt the shape with
            ``_match_data_shape``; otherwise use ``np.broadcast_to``.

    Returns:
        The boolean mask.
    """
    bool_mask = (label_data != 0) if label is None else (label_data == label)

    if data_shape is None or label_data.shape == data_shape:
        return bool_mask
    if allow_scaling:
        return _match_data_shape(bool_mask, data_shape)
    return np.broadcast_to(bool_mask, data_shape)
273
+
274
+
275
def build_masked_numpy_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    fill_value: int | float = 0,
    allow_scaling: bool = True,
    remove_channel_selection: bool = False,
) -> Callable[[], np.ndarray]:
    """Build a callable that reads ``zarr_array`` masked by a label image.

    The returned zero-argument function does four things:
      1. read the data region;
      2. read the matching label region;
      3. turn the labels into a boolean mask shaped like the data;
      4. return the data with every position outside the mask set to
         ``fill_value``.

    Args:
        zarr_array: Data array to read.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimensions of ``label_zarr_array``.
        label_id: Label value to keep; ``None`` keeps every non-zero label.
        axes_order: Axes order requested for both data and labels.
        transforms: Transforms applied to the data after reading.
        label_transforms: Transforms applied to the labels after reading.
        slicing_dict: Per-axis selection for the data.
        label_slicing_dict: Per-axis selection for the labels. When ``None``
            or empty, the data ``slicing_dict`` is reused.
        fill_value: Value placed outside the mask.
        allow_scaling: If ``True``, the mask is zoomed (``order=0``) to the
            data shape; otherwise it is broadcast.
        remove_channel_selection: Forwarded to the data getter. The label
            getter always uses ``True`` (presumably because labels carry no
            channel axis — TODO confirm).

    Returns:
        A function returning the masked numpy array.
    """
    if not slicing_dict:
        slicing_dict = {}
    if not label_slicing_dict:
        label_slicing_dict = slicing_dict

    read_data = build_numpy_getter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )
    read_label = build_numpy_getter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_transforms,
        slicing_dict=label_slicing_dict,
        remove_channel_selection=True,
    )

    def _read_masked() -> np.ndarray:
        values = read_data()
        mask = _label_to_bool_mask_numpy(
            label_data=read_label(),
            label=label_id,
            data_shape=values.shape,
            allow_scaling=allow_scaling,
        )
        return np.where(mask, values, fill_value)

    return _read_masked
326
+
327
+
328
def _match_data_shape_dask(mask: da.Array, data_shape: tuple[int, ...]) -> da.Array:
    """Resize a dask ``mask`` so that its shape equals ``data_shape``.

    This is the dask counterpart of ``_match_data_shape``. Leading singleton
    axes are prepended when ``mask`` has fewer dimensions, then each axis is
    scaled by ``data_size / mask_size`` with ``dask_zoom`` and ``order=0``.
    Equal sizes get ``1.0``; a singleton axis gets the integer data size.

    Raises:
        ValueError: If ``mask`` has more dimensions than ``data_shape``.
    """
    n_missing = len(data_shape) - mask.ndim
    if n_missing < 0:
        raise ValueError(
            "The mask has more dimensions than the data and cannot be matched."
        )
    if n_missing > 0:
        mask = da.reshape(mask, (1,) * n_missing + mask.shape)

    scale = tuple(
        1.0 if m_size == d_size else (d_size if m_size == 1 else d_size / m_size)
        for d_size, m_size in zip(data_shape, mask.shape, strict=True)
    )
    return dask_zoom(mask, scale=scale, order=0)
348
+
349
+
350
def _label_to_bool_mask_dask(
    label_data: DaskArray,
    label: int | None = None,
    data_shape: tuple[int, ...] | None = None,
    allow_scaling: bool = True,
) -> DaskArray:
    """Turn dask label data into a boolean mask, optionally shaped like data.

    Args:
        label_data: Label image as a dask array.
        label: Label value to select. If ``None``, every non-zero value is
            selected.
        data_shape: Target shape. If given and different from
            ``label_data.shape``, the mask is adapted to it.
        allow_scaling: If ``True``, adapt the shape with
            ``_match_data_shape_dask``; otherwise use ``da.broadcast_to``.

    Returns:
        The boolean mask as a dask array.
    """
    bool_mask = (label_data != 0) if label is None else (label_data == label)

    if data_shape is None or label_data.shape == data_shape:
        return bool_mask
    if allow_scaling:
        return _match_data_shape_dask(bool_mask, data_shape)
    return da.broadcast_to(bool_mask, data_shape)
368
+
369
+
370
def build_masked_dask_getter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    fill_value: int | float = 0,
    allow_scaling: bool = True,
    remove_channel_selection: bool = False,
) -> Callable[[], DaskArray]:
    """Build a callable that reads ``zarr_array`` as dask, masked by labels.

    This is the dask counterpart of ``build_masked_numpy_getter``. The
    returned function reads data and labels as dask arrays, builds a boolean
    mask shaped like the data, and returns ``da.where(mask, data,
    fill_value)``.

    Args:
        zarr_array: Data array to read.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimensions of ``label_zarr_array``.
        label_id: Label value to keep; ``None`` keeps every non-zero label.
        axes_order: Axes order requested for both data and labels.
        transforms: Transforms applied to the data after reading.
        label_transforms: Transforms applied to the labels after reading.
        slicing_dict: Per-axis selection for the data.
        label_slicing_dict: Per-axis selection for the labels. When ``None``
            or empty, the data ``slicing_dict`` is reused.
        fill_value: Value placed outside the mask.
        allow_scaling: If ``True``, the mask is zoomed (``order=0``) to the
            data shape; otherwise it is broadcast.
        remove_channel_selection: Forwarded to the data getter. The label
            getter always uses ``True``.

    Returns:
        A function returning the masked dask array.
    """
    if not slicing_dict:
        slicing_dict = {}
    if not label_slicing_dict:
        label_slicing_dict = slicing_dict

    read_data = build_dask_getter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )
    read_label = build_dask_getter(
        zarr_array=label_zarr_array,
        dimensions=label_dimensions,
        axes_order=axes_order,
        transforms=label_transforms,
        slicing_dict=label_slicing_dict,
        remove_channel_selection=True,
    )

    def _read_masked() -> DaskArray:
        values = read_data()
        labels = read_label()
        # Dask shapes are normalized to plain ints before comparison.
        target_shape = tuple(map(int, values.shape))
        mask = _label_to_bool_mask_dask(
            label_data=labels,
            label=label_id,
            data_shape=target_shape,
            allow_scaling=allow_scaling,
        )
        return da.where(mask, values, fill_value)

    return _read_masked
422
+
423
+
424
def build_masked_numpy_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    data_getter: Callable[[], np.ndarray] | None = None,
    label_data_getter: Callable[[], np.ndarray] | None = None,
    allow_scaling: bool = True,
    remove_channel_selection: bool = False,
) -> Callable[[np.ndarray], None]:
    """Build a callable that writes a numpy patch only inside a label mask.

    The returned function does four things:
      1. read the current data;
      2. read the labels;
      3. build a boolean mask shaped like the data;
      4. write ``np.where(mask, patch, data)``.

    Positions inside the mask take the patch values; positions outside keep
    the values read in step 1.

    Args:
        zarr_array: Data array to write into.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimensions of ``label_zarr_array``.
        label_id: Label value to write into; ``None`` means every non-zero
            label.
        axes_order: Axes order for data, labels and the incoming patch.
        transforms: Data transforms (forward on read, inverse on write).
        label_transforms: Transforms applied to the labels after reading.
        slicing_dict: Per-axis selection for the data.
        label_slicing_dict: Per-axis selection for the labels. When ``None``
            or empty, the data ``slicing_dict`` is reused.
        data_getter: Optional prebuilt getter for the current data; built
            with ``build_numpy_getter`` when ``None``.
        label_data_getter: Optional prebuilt getter for the labels; built
            with ``build_numpy_getter`` when ``None``.
        allow_scaling: If ``True``, the mask is zoomed (``order=0``) to the
            data shape; otherwise it is broadcast.
        remove_channel_selection: Forwarded to the data getter and setter.
            The label getter always uses ``True``.

    Returns:
        A function that takes a numpy patch and returns ``None``.
    """
    if not slicing_dict:
        slicing_dict = {}
    if not label_slicing_dict:
        label_slicing_dict = slicing_dict

    read_data = (
        data_getter
        if data_getter is not None
        else build_numpy_getter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=slicing_dict,
            remove_channel_selection=remove_channel_selection,
        )
    )
    read_label = (
        label_data_getter
        if label_data_getter is not None
        else build_numpy_getter(
            zarr_array=label_zarr_array,
            dimensions=label_dimensions,
            axes_order=axes_order,
            transforms=label_transforms,
            slicing_dict=label_slicing_dict,
            remove_channel_selection=True,
        )
    )
    write_data = build_numpy_setter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )

    def _write_masked(patch: np.ndarray) -> None:
        """Set a numpy patch to the array, masked by the label array."""
        current = read_data()
        mask = _label_to_bool_mask_numpy(
            label_data=read_label(),
            label=label_id,
            data_shape=current.shape,
            allow_scaling=allow_scaling,
        )
        write_data(np.where(mask, patch, current))

    return _write_masked
488
+
489
+
490
def build_masked_dask_setter(
    *,
    zarr_array: zarr.Array,
    dimensions: Dimensions,
    label_zarr_array: zarr.Array,
    label_dimensions: Dimensions,
    label_id: int | None,
    axes_order: Sequence[str] | None = None,
    transforms: Sequence[TransformProtocol] | None = None,
    label_transforms: Sequence[TransformProtocol] | None = None,
    slicing_dict: dict[str, SlicingInputType] | None = None,
    label_slicing_dict: dict[str, SlicingInputType] | None = None,
    data_getter: Callable[[], DaskArray] | None = None,
    label_data_getter: Callable[[], DaskArray] | None = None,
    allow_scaling: bool = True,
    remove_channel_selection: bool = False,
) -> Callable[[DaskArray], None]:
    """Build a callable that writes a dask patch only inside a label mask.

    This is the dask counterpart of ``build_masked_numpy_setter``. The
    returned function writes ``da.where(mask, patch, data)`` through a setter
    built with ``build_dask_setter``.

    Args:
        zarr_array: Data array to write into.
        dimensions: Dimensions of ``zarr_array``.
        label_zarr_array: Label array used to build the mask.
        label_dimensions: Dimensions of ``label_zarr_array``.
        label_id: Label value to write into; ``None`` means every non-zero
            label.
        axes_order: Axes order for data, labels and the incoming patch.
        transforms: Data transforms (forward on read, inverse on write).
        label_transforms: Transforms applied to the labels after reading.
        slicing_dict: Per-axis selection for the data.
        label_slicing_dict: Per-axis selection for the labels. When ``None``
            or empty, the data ``slicing_dict`` is reused.
        data_getter: Optional prebuilt getter for the current data; built
            with ``build_dask_getter`` when ``None``.
        label_data_getter: Optional prebuilt getter for the labels; built
            with ``build_dask_getter`` when ``None``.
        allow_scaling: If ``True``, the mask is zoomed (``order=0``) to the
            data shape; otherwise it is broadcast.
        remove_channel_selection: Forwarded to the data getter and setter.
            The label getter always uses ``True``.

    Returns:
        A function that takes a dask patch and returns ``None``.
    """
    if not slicing_dict:
        slicing_dict = {}
    if not label_slicing_dict:
        label_slicing_dict = slicing_dict

    read_data = (
        data_getter
        if data_getter is not None
        else build_dask_getter(
            zarr_array=zarr_array,
            dimensions=dimensions,
            axes_order=axes_order,
            transforms=transforms,
            slicing_dict=slicing_dict,
            remove_channel_selection=remove_channel_selection,
        )
    )
    read_label = (
        label_data_getter
        if label_data_getter is not None
        else build_dask_getter(
            zarr_array=label_zarr_array,
            dimensions=label_dimensions,
            axes_order=axes_order,
            transforms=label_transforms,
            slicing_dict=label_slicing_dict,
            remove_channel_selection=True,
        )
    )
    write_data = build_dask_setter(
        zarr_array=zarr_array,
        dimensions=dimensions,
        axes_order=axes_order,
        transforms=transforms,
        slicing_dict=slicing_dict,
        remove_channel_selection=remove_channel_selection,
    )

    def _write_masked(patch: DaskArray) -> None:
        """Set a dask patch to the array, masked by the label array."""
        current = read_data()
        labels = read_label()
        target_shape = tuple(map(int, current.shape))
        mask = _label_to_bool_mask_dask(
            label_data=labels,
            label=label_id,
            data_shape=target_shape,
            allow_scaling=allow_scaling,
        )
        # NOTE(review): with the default data_getter, the dask graph handed to
        # write_data reads from the same region of zarr_array that it writes.
        # Confirm set_dask_patch is safe for this (e.g. computes before writing).
        write_data(da.where(mask, patch, current))

    return _write_masked