ngio 0.2.0a2__py3-none-any.whl → 0.5.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. ngio/__init__.py +40 -12
  2. ngio/common/__init__.py +16 -32
  3. ngio/common/_dimensions.py +270 -48
  4. ngio/common/_masking_roi.py +153 -0
  5. ngio/common/_pyramid.py +267 -73
  6. ngio/common/_roi.py +290 -66
  7. ngio/common/_synt_images_utils.py +101 -0
  8. ngio/common/_zoom.py +54 -22
  9. ngio/experimental/__init__.py +5 -0
  10. ngio/experimental/iterators/__init__.py +15 -0
  11. ngio/experimental/iterators/_abstract_iterator.py +390 -0
  12. ngio/experimental/iterators/_feature.py +189 -0
  13. ngio/experimental/iterators/_image_processing.py +130 -0
  14. ngio/experimental/iterators/_mappers.py +48 -0
  15. ngio/experimental/iterators/_rois_utils.py +126 -0
  16. ngio/experimental/iterators/_segmentation.py +235 -0
  17. ngio/hcs/__init__.py +17 -58
  18. ngio/hcs/_plate.py +1354 -0
  19. ngio/images/__init__.py +30 -9
  20. ngio/images/_abstract_image.py +968 -0
  21. ngio/images/_create_synt_container.py +132 -0
  22. ngio/images/_create_utils.py +423 -0
  23. ngio/images/_image.py +926 -0
  24. ngio/images/_label.py +417 -0
  25. ngio/images/_masked_image.py +531 -0
  26. ngio/images/_ome_zarr_container.py +1235 -0
  27. ngio/images/_table_ops.py +471 -0
  28. ngio/io_pipes/__init__.py +75 -0
  29. ngio/io_pipes/_io_pipes.py +361 -0
  30. ngio/io_pipes/_io_pipes_masked.py +488 -0
  31. ngio/io_pipes/_io_pipes_roi.py +146 -0
  32. ngio/io_pipes/_io_pipes_types.py +56 -0
  33. ngio/io_pipes/_match_shape.py +377 -0
  34. ngio/io_pipes/_ops_axes.py +344 -0
  35. ngio/io_pipes/_ops_slices.py +411 -0
  36. ngio/io_pipes/_ops_slices_utils.py +199 -0
  37. ngio/io_pipes/_ops_transforms.py +104 -0
  38. ngio/io_pipes/_zoom_transform.py +180 -0
  39. ngio/ome_zarr_meta/__init__.py +39 -15
  40. ngio/ome_zarr_meta/_meta_handlers.py +490 -96
  41. ngio/ome_zarr_meta/ngio_specs/__init__.py +24 -10
  42. ngio/ome_zarr_meta/ngio_specs/_axes.py +268 -234
  43. ngio/ome_zarr_meta/ngio_specs/_channels.py +125 -41
  44. ngio/ome_zarr_meta/ngio_specs/_dataset.py +42 -87
  45. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +536 -2
  46. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +202 -198
  47. ngio/ome_zarr_meta/ngio_specs/_pixel_size.py +72 -34
  48. ngio/ome_zarr_meta/v04/__init__.py +21 -5
  49. ngio/ome_zarr_meta/v04/_custom_models.py +18 -0
  50. ngio/ome_zarr_meta/v04/{_v04_spec_utils.py → _v04_spec.py} +151 -90
  51. ngio/ome_zarr_meta/v05/__init__.py +27 -0
  52. ngio/ome_zarr_meta/v05/_custom_models.py +18 -0
  53. ngio/ome_zarr_meta/v05/_v05_spec.py +511 -0
  54. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  55. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  56. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  57. ngio/resources/__init__.py +55 -0
  58. ngio/resources/resource_model.py +36 -0
  59. ngio/tables/__init__.py +20 -4
  60. ngio/tables/_abstract_table.py +270 -0
  61. ngio/tables/_tables_container.py +449 -0
  62. ngio/tables/backends/__init__.py +50 -1
  63. ngio/tables/backends/_abstract_backend.py +200 -31
  64. ngio/tables/backends/_anndata.py +139 -0
  65. ngio/tables/backends/_anndata_utils.py +10 -114
  66. ngio/tables/backends/_csv.py +19 -0
  67. ngio/tables/backends/_json.py +92 -0
  68. ngio/tables/backends/_parquet.py +19 -0
  69. ngio/tables/backends/_py_arrow_backends.py +222 -0
  70. ngio/tables/backends/_table_backends.py +162 -38
  71. ngio/tables/backends/_utils.py +608 -0
  72. ngio/tables/v1/__init__.py +19 -4
  73. ngio/tables/v1/_condition_table.py +71 -0
  74. ngio/tables/v1/_feature_table.py +79 -115
  75. ngio/tables/v1/_generic_table.py +21 -90
  76. ngio/tables/v1/_roi_table.py +486 -137
  77. ngio/transforms/__init__.py +5 -0
  78. ngio/transforms/_zoom.py +19 -0
  79. ngio/utils/__init__.py +16 -14
  80. ngio/utils/_cache.py +48 -0
  81. ngio/utils/_datasets.py +121 -13
  82. ngio/utils/_fractal_fsspec_store.py +42 -0
  83. ngio/utils/_zarr_utils.py +374 -218
  84. ngio-0.5.0b4.dist-info/METADATA +147 -0
  85. ngio-0.5.0b4.dist-info/RECORD +88 -0
  86. {ngio-0.2.0a2.dist-info → ngio-0.5.0b4.dist-info}/WHEEL +1 -1
  87. ngio/common/_array_pipe.py +0 -160
  88. ngio/common/_axes_transforms.py +0 -63
  89. ngio/common/_common_types.py +0 -5
  90. ngio/common/_slicer.py +0 -97
  91. ngio/images/abstract_image.py +0 -240
  92. ngio/images/create.py +0 -251
  93. ngio/images/image.py +0 -389
  94. ngio/images/label.py +0 -236
  95. ngio/images/omezarr_container.py +0 -535
  96. ngio/ome_zarr_meta/_generic_handlers.py +0 -320
  97. ngio/ome_zarr_meta/v04/_meta_handlers.py +0 -54
  98. ngio/tables/_validators.py +0 -192
  99. ngio/tables/backends/_anndata_v1.py +0 -75
  100. ngio/tables/backends/_json_v1.py +0 -56
  101. ngio/tables/tables_container.py +0 -300
  102. ngio/tables/v1/_masking_roi_table.py +0 -175
  103. ngio/utils/_logger.py +0 -29
  104. ngio-0.2.0a2.dist-info/METADATA +0 -95
  105. ngio-0.2.0a2.dist-info/RECORD +0 -53
  106. {ngio-0.2.0a2.dist-info → ngio-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,390 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Callable, Generator
3
+ from typing import Generic, Literal, Self, TypeVar, overload
4
+
5
+ from ngio import Roi
6
+ from ngio.experimental.iterators._mappers import BasicMapper, MapperProtocol
7
+ from ngio.experimental.iterators._rois_utils import (
8
+ by_chunks,
9
+ by_yx,
10
+ by_zyx,
11
+ grid,
12
+ rois_product,
13
+ )
14
+ from ngio.images._abstract_image import AbstractImage
15
+ from ngio.io_pipes._io_pipes_types import DataGetterProtocol, DataSetterProtocol
16
+ from ngio.io_pipes._ops_slices_utils import check_if_regions_overlap
17
+ from ngio.tables import GenericRoiTable
18
+ from ngio.utils import NgioValueError
19
+
20
+ NumpyPipeType = TypeVar("NumpyPipeType")
21
+ DaskPipeType = TypeVar("DaskPipeType")
22
+
23
+
24
class AbstractIteratorBuilder(ABC, Generic[NumpyPipeType, DaskPipeType]):
    """Base class for building iterators over ROIs.

    The builder holds a list of ROIs and a reference image. The ROI-refining
    methods (``grid``, ``by_yx``, ``by_zyx``, ``by_chunks``, ``product``) never
    mutate the current instance: each one returns a new instance created from
    ``get_init_kwargs()`` and carrying the derived ROI list.

    Subclasses must set ``_rois`` and ``_ref_image`` and implement the
    abstract getter/setter factories, ``get_init_kwargs`` and
    ``post_consolidate``.
    """

    # ROIs the iterator walks over, in iteration order.
    _rois: list[Roi]
    # Image handed to the ROI helpers and used for shape/chunks lookups.
    _ref_image: AbstractImage

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(regions={len(self._rois)})"

    @abstractmethod
    def get_init_kwargs(self) -> dict:
        """Return the initialization arguments for the iterator.

        This is used to clone the iterator with the same parameters
        after every "product" operation (see ``_new_from_rois``), so the
        returned mapping must be accepted as ``self.__class__(**kwargs)``.
        """

    @property
    def rois(self) -> list[Roi]:
        """Get the list of ROIs for the iterator."""
        return self._rois

    def _set_rois(self, rois: list[Roi]) -> None:
        """Set the list of ROIs for the iterator."""
        self._rois = rois

    @property
    def ref_image(self) -> AbstractImage:
        """Get the reference image for the iterator."""
        return self._ref_image

    def _new_from_rois(self, rois: list[Roi]) -> Self:
        """Create a new instance of the iterator with a different set of ROIs.

        The clone is built with ``self.__class__(**self.get_init_kwargs())``
        and its ROI list is then replaced by ``rois``.
        """
        clone = self.__class__(**self.get_init_kwargs())
        clone._set_rois(rois)
        return clone

    def grid(
        self,
        size_x: int | None = None,
        size_y: int | None = None,
        size_z: int | None = None,
        size_t: int | None = None,
        stride_x: int | None = None,
        stride_y: int | None = None,
        stride_z: int | None = None,
        stride_t: int | None = None,
        base_name: str = "",
    ) -> Self:
        """Create a grid of ROIs based on the input image dimensions.

        All arguments are forwarded unchanged, together with the current ROIs
        and the reference image, to the ``grid`` helper from ``_rois_utils``.

        Returns:
            Self: A new iterator over the ROIs produced by the helper.
        """
        rois = grid(
            rois=self.rois,
            ref_image=self.ref_image,
            size_x=size_x,
            size_y=size_y,
            size_z=size_z,
            size_t=size_t,
            stride_x=stride_x,
            stride_y=stride_y,
            stride_z=stride_z,
            stride_t=stride_t,
            base_name=base_name,
        )
        return self._new_from_rois(rois)

    def by_yx(self) -> Self:
        """Return a new iterator that iterates over ROIs by YX coordinates."""
        return self._new_from_rois(by_yx(self.rois, self.ref_image))

    def by_zyx(self, strict: bool = True) -> Self:
        """Return a new iterator that iterates over ROIs by ZYX coordinates.

        Args:
            strict (bool): If True, only iterate over ZYX if a Z axis
                is present and not of size 1.

        Returns:
            Self: A new iterator over the ROIs produced by the helper.
        """
        return self._new_from_rois(by_zyx(self.rois, self.ref_image, strict=strict))

    def by_chunks(self, overlap_xy: int = 0, overlap_z: int = 0) -> Self:
        """Return a new iterator that iterates over ROIs by chunks.

        Args:
            overlap_xy (int): Overlap in XY dimensions.
            overlap_z (int): Overlap in Z dimension.

        Returns:
            Self: A new iterator (same class as this one) with chunked ROIs.
        """
        rois = by_chunks(
            self.rois, self.ref_image, overlap_xy=overlap_xy, overlap_z=overlap_z
        )
        return self._new_from_rois(rois)

    def product(self, other: list[Roi] | GenericRoiTable) -> Self:
        """Cartesian product of the current ROIs with an arbitrary list of ROIs.

        Args:
            other (list[Roi] | GenericRoiTable): ROIs to combine with the
                current ones. A ROI table is first converted with ``.rois()``.

        Returns:
            Self: A new iterator over the combined ROIs.
        """
        if isinstance(other, GenericRoiTable):
            other = other.rois()
        return self._new_from_rois(rois_product(self.rois, other))

    @abstractmethod
    def build_numpy_getter(self, roi: Roi) -> DataGetterProtocol[NumpyPipeType]:
        """Build a numpy getter function for the given ROI."""
        raise NotImplementedError

    @abstractmethod
    def build_numpy_setter(self, roi: Roi) -> DataSetterProtocol[NumpyPipeType] | None:
        """Build a numpy setter function for the given ROI.

        Returning ``None`` marks the iterator as read-only: read/write
        iteration then raises ``NgioValueError``.
        """
        raise NotImplementedError

    @abstractmethod
    def build_dask_getter(self, roi: Roi) -> DataGetterProtocol[DaskPipeType]:
        """Build a Dask getter function for the given ROI."""
        raise NotImplementedError

    @abstractmethod
    def build_dask_setter(self, roi: Roi) -> DataSetterProtocol[DaskPipeType] | None:
        """Build a Dask setter function for the given ROI.

        Returning ``None`` marks the iterator as read-only: read/write
        iteration then raises ``NgioValueError``.
        """
        raise NotImplementedError

    @abstractmethod
    def post_consolidate(self) -> None:
        """Post-process the consolidated data."""
        raise NotImplementedError

    def _numpy_getters_generator(self) -> Generator[DataGetterProtocol[NumpyPipeType]]:
        """Lazily yield one numpy getter per ROI, in ROI order."""
        yield from (self.build_numpy_getter(roi) for roi in self.rois)

    def _dask_getters_generator(self) -> Generator[DataGetterProtocol[DaskPipeType]]:
        """Lazily yield one dask getter per ROI, in ROI order."""
        yield from (self.build_dask_getter(roi) for roi in self.rois)

    def _numpy_setters_generator(
        self,
    ) -> Generator[DataSetterProtocol[NumpyPipeType] | None]:
        """Lazily yield one numpy setter (or None) per ROI, in ROI order."""
        yield from (self.build_numpy_setter(roi) for roi in self.rois)

    def _dask_setters_generator(
        self,
    ) -> Generator[DataSetterProtocol[DaskPipeType] | None]:
        """Lazily yield one dask setter (or None) per ROI, in ROI order."""
        yield from (self.build_dask_setter(roi) for roi in self.rois)

    def _read_and_write_generator(
        self,
        getters: Generator[
            DataGetterProtocol[NumpyPipeType] | DataGetterProtocol[DaskPipeType]
        ],
        setters: Generator[
            DataSetterProtocol[NumpyPipeType] | DataSetterProtocol[DaskPipeType] | None
        ],
    ) -> Generator[
        tuple[
            DataGetterProtocol[NumpyPipeType] | DataGetterProtocol[DaskPipeType],
            DataSetterProtocol[NumpyPipeType] | DataSetterProtocol[DaskPipeType],
        ]
    ]:
        """Yield ``(getter, setter)`` pairs, then run ``post_consolidate``.

        Args:
            getters: Generator of getters, one per ROI.
            setters: Generator of setters, one per ROI, paired positionally
                with ``getters`` (``zip(..., strict=True)``).

        Raises:
            NgioValueError: If a setter is ``None`` (read-only iterator).

        Note:
            ``post_consolidate`` is only called once the generator has been
            fully consumed; abandoning the iteration early skips it.
        """
        for getter, setter in zip(getters, setters, strict=True):
            if setter is None:
                name = self.__class__.__name__
                raise NgioValueError(f"Iterator is read-only: {name}")
            yield getter, setter
        self.post_consolidate()

    @overload
    def iter(
        self,
        lazy: Literal[True],
        data_mode: Literal["numpy"],
        iterator_mode: Literal["readwrite"],
    ) -> Generator[
        tuple[DataGetterProtocol[NumpyPipeType], DataSetterProtocol[NumpyPipeType]]
    ]: ...

    @overload
    def iter(
        self,
        lazy: Literal[True],
        data_mode: Literal["numpy"],
        iterator_mode: Literal["readonly"] = ...,
    ) -> Generator[DataGetterProtocol[NumpyPipeType]]: ...

    @overload
    def iter(
        self,
        lazy: Literal[True],
        data_mode: Literal["dask"],
        iterator_mode: Literal["readwrite"],
    ) -> Generator[
        tuple[DataGetterProtocol[DaskPipeType], DataSetterProtocol[DaskPipeType]]
    ]: ...

    @overload
    def iter(
        self,
        lazy: Literal[True],
        data_mode: Literal["dask"],
        iterator_mode: Literal["readonly"] = ...,
    ) -> Generator[DataGetterProtocol[DaskPipeType]]: ...

    @overload
    def iter(
        self,
        lazy: Literal[False],
        data_mode: Literal["numpy"],
        iterator_mode: Literal["readwrite"],
    ) -> Generator[tuple[NumpyPipeType, DataSetterProtocol[NumpyPipeType]]]: ...

    @overload
    def iter(
        self,
        lazy: Literal[False],
        data_mode: Literal["numpy"],
        iterator_mode: Literal["readonly"] = ...,
    ) -> Generator[NumpyPipeType]: ...

    @overload
    def iter(
        self,
        lazy: Literal[False],
        data_mode: Literal["dask"],
        iterator_mode: Literal["readwrite"],
    ) -> Generator[tuple[DaskPipeType, DataSetterProtocol[DaskPipeType]]]: ...

    @overload
    def iter(
        self,
        lazy: Literal[False],
        data_mode: Literal["dask"],
        iterator_mode: Literal["readonly"] = ...,
    ) -> Generator[DaskPipeType]: ...

    def iter(
        self,
        lazy: bool = False,
        data_mode: Literal["numpy", "dask"] = "dask",
        iterator_mode: Literal["readwrite", "readonly"] = "readwrite",
    ) -> Generator:
        """Create an iterator over the pixels of the ROIs.

        Args:
            lazy (bool): If True, yield the getter objects themselves; if
                False, call each getter and yield its result.
            data_mode: ``"numpy"`` or ``"dask"``; selects which getter/setter
                factories are used.
            iterator_mode: ``"readonly"`` yields only the data (or getters);
                ``"readwrite"`` yields ``(data_or_getter, setter)`` pairs and
                calls ``post_consolidate`` once the iteration is exhausted.

        Returns:
            Generator: A generator over the ROIs, shaped as described above.

        Raises:
            NgioValueError: If ``data_mode`` or ``iterator_mode`` is not one
                of the supported values, or (while iterating in read/write
                mode) if the iterator is read-only.
        """
        if data_mode not in ("numpy", "dask"):
            raise NgioValueError(f"Invalid mode: {data_mode}")
        # Reject unknown modes explicitly: silently treating a typo such as
        # "read_only" as "readwrite" would hand out setters by accident.
        if iterator_mode not in ("readwrite", "readonly"):
            raise NgioValueError(f"Invalid iterator mode: {iterator_mode}")

        use_numpy = data_mode == "numpy"
        if use_numpy:
            getters = self._numpy_getters_generator()
        else:
            getters = self._dask_getters_generator()

        if iterator_mode == "readonly":
            if lazy:
                return getters
            return (getter() for getter in getters)

        if use_numpy:
            setters = self._numpy_setters_generator()
        else:
            setters = self._dask_setters_generator()
        pairs = self._read_and_write_generator(getters, setters)
        if lazy:
            return pairs
        return ((getter(), setter) for getter, setter in pairs)

    def iter_as_numpy(
        self,
    ) -> Generator[tuple[NumpyPipeType, DataSetterProtocol[NumpyPipeType]]]:
        """Iterate eagerly over the ROIs as numpy data, in read/write mode."""
        return self.iter(lazy=False, data_mode="numpy", iterator_mode="readwrite")

    def iter_as_dask(
        self,
    ) -> Generator[tuple[DaskPipeType, DataSetterProtocol[DaskPipeType]]]:
        """Iterate eagerly over the ROIs as dask data, in read/write mode."""
        return self.iter(lazy=False, data_mode="dask", iterator_mode="readwrite")

    def map_as_numpy(
        self,
        func: Callable[[NumpyPipeType], NumpyPipeType],
        mapper: MapperProtocol[NumpyPipeType] | None = None,
    ) -> None:
        """Apply a transformation function to the ROI pixels (numpy mode).

        Args:
            func: Transformation handed to the mapper.
            mapper: Object that drives the getters/setters; a ``BasicMapper``
                is used when None.

        ``post_consolidate`` is called after the mapper returns.
        """
        _mapper = BasicMapper[NumpyPipeType]() if mapper is None else mapper
        _mapper(
            func=func,
            getters=self._numpy_getters_generator(),
            setters=self._numpy_setters_generator(),
        )
        self.post_consolidate()

    def map_as_dask(
        self,
        func: Callable[[DaskPipeType], DaskPipeType],
        mapper: MapperProtocol[DaskPipeType] | None = None,
    ) -> None:
        """Apply a transformation function to the ROI pixels (dask mode).

        Args:
            func: Transformation handed to the mapper.
            mapper: Object that drives the getters/setters; a ``BasicMapper``
                is used when None.

        ``post_consolidate`` is called after the mapper returns.
        """
        _mapper = BasicMapper[DaskPipeType]() if mapper is None else mapper
        _mapper(
            func=func,
            getters=self._dask_getters_generator(),
            setters=self._dask_setters_generator(),
        )
        self.post_consolidate()

    def check_if_regions_overlap(self) -> bool:
        """Check if any of the ROIs overlap logically.

        If two ROIs cover the same pixel, they are considered to overlap.
        This does not consider chunking or other storage details.

        Returns:
            bool: True if any ROIs overlap. False otherwise.
        """
        if len(self.rois) < 2:
            # Less than 2 ROIs cannot overlap
            return False

        slicing_tuples = (
            g.slicing_ops.normalized_slicing_tuple
            for g in self._numpy_getters_generator()
        )
        # Inside a method body this name resolves to the module-level helper,
        # not to this method.
        return check_if_regions_overlap(slicing_tuples)

    def require_no_regions_overlap(self) -> None:
        """Ensure that the Iterator's ROIs do not overlap.

        Raises:
            NgioValueError: If ``check_if_regions_overlap`` returns True.
        """
        if self.check_if_regions_overlap():
            raise NgioValueError("Some rois overlap.")

    def check_if_chunks_overlap(self) -> bool:
        """Check if any of the ROIs overlap in terms of chunks.

        If two ROIs cover the same chunk, they are considered to overlap in chunks.
        This does not consider pixel-level overlaps.

        Returns:
            bool: True if any ROIs overlap in chunks, False otherwise.
        """
        # Kept as a function-local import, as in the original.
        # NOTE(review): could probably live next to the module-level import of
        # check_if_regions_overlap - confirm there is no import-cycle reason.
        from ngio.io_pipes._ops_slices_utils import check_if_chunks_overlap

        if len(self.rois) < 2:
            # Less than 2 ROIs cannot overlap
            return False

        slicing_tuples = (
            g.slicing_ops.normalized_slicing_tuple
            for g in self._numpy_getters_generator()
        )
        return check_if_chunks_overlap(
            slicing_tuples, self.ref_image.shape, self.ref_image.chunks
        )

    def require_no_chunks_overlap(self) -> None:
        """Ensure that the ROIs do not overlap in terms of chunks.

        Raises:
            NgioValueError: If ``check_if_chunks_overlap`` returns True.
        """
        if self.check_if_chunks_overlap():
            raise NgioValueError("Some rois overlap in chunks.")
@@ -0,0 +1,189 @@
1
+ from collections.abc import Sequence
2
+ from typing import TypeAlias
3
+
4
+ import dask.array as da
5
+ import numpy as np
6
+
7
+ from ngio.common import Roi
8
+ from ngio.experimental.iterators._abstract_iterator import AbstractIteratorBuilder
9
+ from ngio.images import Image, Label
10
+ from ngio.images._image import (
11
+ ChannelSlicingInputType,
12
+ add_channel_selection_to_slicing_dict,
13
+ )
14
+ from ngio.io_pipes import (
15
+ DaskRoiGetter,
16
+ DataGetter,
17
+ NumpyRoiGetter,
18
+ TransformProtocol,
19
+ )
20
+
21
+ NumpyPipeType: TypeAlias = tuple[np.ndarray, np.ndarray, Roi]
22
+ DaskPipeType: TypeAlias = tuple[da.Array, da.Array, Roi]
23
+
24
+
25
class NumpyFeatureGetter(DataGetter[NumpyPipeType]):
    """Getter that reads an image region and its label region as numpy data.

    The base ``DataGetter`` is configured from the image getter only
    (zarr array, slicing ops, axes ops, transforms and ROI); the label getter
    is simply called alongside it.
    """

    def __init__(
        self,
        image_getter: NumpyRoiGetter,
        label_getter: NumpyRoiGetter,
    ) -> None:
        """Store both getters and initialise the base from the image getter.

        Args:
            image_getter (NumpyRoiGetter): Getter for the image data.
            label_getter (NumpyRoiGetter): Getter for the label data.
        """
        self._image_getter = image_getter
        self._label_getter = label_getter
        base_config = {
            "zarr_array": image_getter.zarr_array,
            "slicing_ops": image_getter.slicing_ops,
            "axes_ops": image_getter.axes_ops,
            "transforms": image_getter.transforms,
            "roi": image_getter.roi,
        }
        super().__init__(**base_config)

    def get(self) -> NumpyPipeType:
        """Return ``(image, label, roi)``, reading the image before the label."""
        image_data = self._image_getter()
        label_data = self._label_getter()
        return image_data, label_data, self.roi

    @property
    def image(self) -> np.ndarray:
        """Image data for the ROI; the image getter is called on each access."""
        image_data = self._image_getter()
        return image_data

    @property
    def label(self) -> np.ndarray:
        """Label data for the ROI; the label getter is called on each access."""
        label_data = self._label_getter()
        return label_data
51
+
52
+
53
class DaskFeatureGetter(DataGetter[DaskPipeType]):
    """Getter that reads an image region and its label region as dask arrays.

    The base ``DataGetter`` is configured from the image getter only
    (zarr array, slicing ops, axes ops, transforms and ROI); the label getter
    is simply called alongside it.
    """

    def __init__(
        self,
        image_getter: DaskRoiGetter,
        label_getter: DaskRoiGetter,
    ) -> None:
        """Store both getters and initialise the base from the image getter.

        Args:
            image_getter (DaskRoiGetter): Getter for the image data.
            label_getter (DaskRoiGetter): Getter for the label data.
        """
        self._image_getter = image_getter
        self._label_getter = label_getter
        base_config = {
            "zarr_array": image_getter.zarr_array,
            "slicing_ops": image_getter.slicing_ops,
            "axes_ops": image_getter.axes_ops,
            "transforms": image_getter.transforms,
            "roi": image_getter.roi,
        }
        super().__init__(**base_config)

    def get(self) -> DaskPipeType:
        """Return ``(image, label, roi)``, calling the image getter first."""
        image_data = self._image_getter()
        label_data = self._label_getter()
        return image_data, label_data, self.roi

    @property
    def image(self) -> da.Array:
        """Image array for the ROI; the image getter is called on each access."""
        image_data = self._image_getter()
        return image_data

    @property
    def label(self) -> da.Array:
        """Label array for the ROI; the label getter is called on each access."""
        label_data = self._label_getter()
        return label_data
79
+
80
+
81
class FeatureExtractorIterator(AbstractIteratorBuilder[NumpyPipeType, DaskPipeType]):
    """Read-only iterator yielding ``(image, label, roi)`` for each ROI.

    Both setter factories return ``None``, so only read-only iteration is
    supported; ``iter_as_numpy`` and ``iter_as_dask`` are overridden
    accordingly.
    """

    def __init__(
        self,
        input_image: Image,
        input_label: Label,
        channel_selection: ChannelSlicingInputType = None,
        axes_order: Sequence[str] | None = None,
        input_transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
    ) -> None:
        """Initialize the iterator with an input image and an input label.

        The initial ROI list comes from
        ``input_image.build_image_roi_table(name=None).rois()`` and the input
        image is used as the reference image.

        Args:
            input_image (Image): The image to read pixel data from.
            input_label (Label): The label with the segmentation masks.
            channel_selection (ChannelSlicingInputType): Optional
                selection of channels, applied to the input image only.
            axes_order (Sequence[str] | None): Optional axes order, passed
                to both the image and the label getters.
            input_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the input image.
            label_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the input label.

        Raises:
            Whatever ``input_image.require_axes_match`` and
            ``input_image.require_rescalable`` raise when the label is not
            compatible with the image.
        """
        self._input = input_image
        self._input_label = input_label
        self._ref_image = input_image
        self._rois = input_image.build_image_roi_table(name=None).rois()

        # Set iteration parameters
        self._input_slicing_kwargs = add_channel_selection_to_slicing_dict(
            image=self._input, channel_selection=channel_selection, slicing_dict={}
        )
        self._channel_selection = channel_selection
        self._axes_order = axes_order
        self._input_transforms = input_transforms
        self._label_transforms = label_transforms

        self._input.require_axes_match(self._input_label)
        self._input.require_rescalable(self._input_label)

    def get_init_kwargs(self) -> dict:
        """Return the initialization arguments for the iterator."""
        return {
            "input_image": self._input,
            "input_label": self._input_label,
            "channel_selection": self._channel_selection,
            "axes_order": self._axes_order,
            "input_transforms": self._input_transforms,
            "label_transforms": self._label_transforms,
        }

    def _image_getter_kwargs(self, roi: Roi) -> dict:
        """Keyword arguments shared by the numpy and dask image getters."""
        return {
            "zarr_array": self._input.zarr_array,
            "dimensions": self._input.dimensions,
            "axes_order": self._axes_order,
            "transforms": self._input_transforms,
            "roi": roi,
            "slicing_dict": self._input_slicing_kwargs,
        }

    def _label_getter_kwargs(self, roi: Roi) -> dict:
        """Keyword arguments shared by the numpy and dask label getters."""
        return {
            "zarr_array": self._input_label.zarr_array,
            "dimensions": self._input_label.dimensions,
            "axes_order": self._axes_order,
            "transforms": self._label_transforms,
            "roi": roi,
            # The channel selection is applied to the image only.
            "remove_channel_selection": True,
        }

    def build_numpy_getter(self, roi: Roi) -> NumpyFeatureGetter:
        """Build the numpy getter pairing image and label reads for ``roi``."""
        data_getter = NumpyRoiGetter(**self._image_getter_kwargs(roi))
        label_getter = NumpyRoiGetter(**self._label_getter_kwargs(roi))
        return NumpyFeatureGetter(data_getter, label_getter)

    def build_dask_getter(self, roi: Roi) -> DaskFeatureGetter:
        """Build the dask getter pairing image and label reads for ``roi``."""
        data_getter = DaskRoiGetter(**self._image_getter_kwargs(roi))
        label_getter = DaskRoiGetter(**self._label_getter_kwargs(roi))
        return DaskFeatureGetter(data_getter, label_getter)

    def build_numpy_setter(self, roi: Roi) -> None:
        """Return None: this iterator does not write data."""
        return None

    def build_dask_setter(self, roi: Roi) -> None:
        """Return None: this iterator does not write data."""
        return None

    def post_consolidate(self) -> None:
        """Do nothing: there is no written data to consolidate."""
        pass

    def iter_as_numpy(self):  # type: ignore[override]
        """Iterate eagerly over the ROIs as numpy data, in read-only mode."""
        return self.iter(lazy=False, data_mode="numpy", iterator_mode="readonly")

    def iter_as_dask(self):  # type: ignore[override]
        """Iterate eagerly over the ROIs as dask data, in read-only mode."""
        return self.iter(lazy=False, data_mode="dask", iterator_mode="readonly")