ngio 0.3.5__py3-none-any.whl → 0.4.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ngio/__init__.py +6 -0
  2. ngio/common/__init__.py +50 -48
  3. ngio/common/_array_io_pipes.py +549 -0
  4. ngio/common/_array_io_utils.py +508 -0
  5. ngio/common/_dimensions.py +63 -27
  6. ngio/common/_masking_roi.py +38 -10
  7. ngio/common/_pyramid.py +9 -7
  8. ngio/common/_roi.py +571 -72
  9. ngio/common/_synt_images_utils.py +101 -0
  10. ngio/common/_zoom.py +17 -12
  11. ngio/common/transforms/__init__.py +5 -0
  12. ngio/common/transforms/_label.py +12 -0
  13. ngio/common/transforms/_zoom.py +109 -0
  14. ngio/experimental/__init__.py +5 -0
  15. ngio/experimental/iterators/__init__.py +17 -0
  16. ngio/experimental/iterators/_abstract_iterator.py +170 -0
  17. ngio/experimental/iterators/_feature.py +151 -0
  18. ngio/experimental/iterators/_image_processing.py +169 -0
  19. ngio/experimental/iterators/_rois_utils.py +127 -0
  20. ngio/experimental/iterators/_segmentation.py +278 -0
  21. ngio/hcs/_plate.py +41 -36
  22. ngio/images/__init__.py +22 -1
  23. ngio/images/_abstract_image.py +247 -117
  24. ngio/images/_create.py +15 -15
  25. ngio/images/_create_synt_container.py +128 -0
  26. ngio/images/_image.py +425 -62
  27. ngio/images/_label.py +33 -30
  28. ngio/images/_masked_image.py +396 -122
  29. ngio/images/_ome_zarr_container.py +203 -66
  30. ngio/{common → images}/_table_ops.py +41 -41
  31. ngio/ome_zarr_meta/ngio_specs/__init__.py +2 -8
  32. ngio/ome_zarr_meta/ngio_specs/_axes.py +151 -128
  33. ngio/ome_zarr_meta/ngio_specs/_channels.py +55 -18
  34. ngio/ome_zarr_meta/ngio_specs/_dataset.py +7 -7
  35. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +3 -3
  36. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +11 -68
  37. ngio/ome_zarr_meta/v04/_v04_spec_utils.py +1 -1
  38. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  39. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  40. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  41. ngio/resources/__init__.py +54 -0
  42. ngio/resources/resource_model.py +35 -0
  43. ngio/tables/backends/_abstract_backend.py +5 -6
  44. ngio/tables/backends/_anndata.py +1 -1
  45. ngio/tables/backends/_anndata_utils.py +3 -3
  46. ngio/tables/backends/_non_zarr_backends.py +1 -1
  47. ngio/tables/backends/_table_backends.py +0 -1
  48. ngio/tables/backends/_utils.py +3 -3
  49. ngio/tables/v1/_roi_table.py +156 -69
  50. ngio/utils/__init__.py +2 -3
  51. ngio/utils/_logger.py +19 -0
  52. ngio/utils/_zarr_utils.py +1 -5
  53. {ngio-0.3.5.dist-info → ngio-0.4.0a1.dist-info}/METADATA +3 -1
  54. ngio-0.4.0a1.dist-info/RECORD +76 -0
  55. ngio/common/_array_pipe.py +0 -288
  56. ngio/common/_axes_transforms.py +0 -64
  57. ngio/common/_common_types.py +0 -5
  58. ngio/common/_slicer.py +0 -96
  59. ngio-0.3.5.dist-info/RECORD +0 -61
  60. {ngio-0.3.5.dist-info → ngio-0.4.0a1.dist-info}/WHEEL +0 -0
  61. {ngio-0.3.5.dist-info → ngio-0.4.0a1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,101 @@
1
+ from math import ceil
2
+
3
+ import numpy as np
4
+
5
+
6
+ def _center_crop(arr: np.ndarray, target: int, axis: int) -> np.ndarray:
7
+ # Center-crop the array `arr` along dimension `axis` to size `target`.
8
+ # This assumes target < arr.shape[axis].
9
+ n = arr.shape[axis]
10
+ start = (n - target) // 2
11
+ end = start + target
12
+ slc = [slice(None)] * arr.ndim
13
+ slc[axis] = slice(start, end)
14
+ return arr[tuple(slc)]
15
+
16
+
17
+ def _tile_to(
18
+ arr: np.ndarray, target: int, axis: int, label_mode: bool = False
19
+ ) -> np.ndarray:
20
+ # Tile the array `arr` along dimension `axis` to size `target`.
21
+ # This assumes target > arr.shape[axis].
22
+ n = arr.shape[axis]
23
+ reps = ceil(target / n)
24
+
25
+ tiles = []
26
+ flip = False
27
+ max_label = 0
28
+ for _ in range(reps):
29
+ if flip:
30
+ t_arr = np.flip(arr, axis=axis)
31
+ else:
32
+ t_arr = 1 * arr
33
+ if label_mode:
34
+ # Remove duplicate labels
35
+ t_arr = np.where(t_arr > 0, t_arr + max_label, 0)
36
+ max_label = t_arr.max()
37
+ tiles.append(t_arr)
38
+ flip = not flip
39
+
40
+ tiled = np.concatenate(tiles, axis=axis)
41
+
42
+ slc = [slice(None)] * arr.ndim
43
+ slc[axis] = slice(0, target)
44
+ return tiled[tuple(slc)]
45
+
46
+
47
+ def _fit_to_shape_2d(
48
+ src: np.ndarray, out_shape: tuple[int, int], label_mode: bool = False
49
+ ) -> np.ndarray:
50
+ """Fit a 2D array to a target shape by center-cropping or tiling as necessary."""
51
+ out_r, out_c = out_shape
52
+ arr = src
53
+ if out_r < arr.shape[0]:
54
+ arr = _center_crop(arr, out_r, axis=0)
55
+ else:
56
+ arr = _tile_to(arr, out_r, axis=0, label_mode=label_mode)
57
+
58
+ if out_c < arr.shape[1]:
59
+ arr = _center_crop(arr, out_c, axis=1)
60
+ else:
61
+ arr = _tile_to(arr, out_c, axis=1, label_mode=label_mode)
62
+ return arr
63
+
64
+
65
def fit_to_shape(
    arr: np.ndarray, out_shape: tuple[int, ...], ensure_unique_info: bool = False
) -> np.ndarray:
    """Fit a 2D array to a target shape.

    The last two entries of `out_shape` are treated as the (y, x) plane. `arr`
    is center-cropped or tiled to that plane. The result is then broadcast to
    the full `out_shape`, so any leading dimensions repeat the same plane.

    WARNING: This does not zoom the image, it only crops or tiles it.

    Args:
        arr (np.ndarray): The input 2D array.
        out_shape (tuple[int, ...]): The target shape. Must have between 2
            and 5 dimensions, all positive.
        ensure_unique_info (bool, optional): If True, `arr` is treated as a
            label image and labels are kept from overlapping when tiling.
            Defaults to False.

    Returns:
        np.ndarray: The fitted array with shape `out_shape`. It comes from
        `np.broadcast_to`, so it is a read-only view. Call `.copy()` before
        writing into it.

    Raises:
        ValueError: If `out_shape` has fewer than 2 or more than 5 dimensions,
            contains a non-positive size, or `arr` is not 2D.
    """
    n_dims = len(out_shape)
    if n_dims < 2:
        raise ValueError("`out_shape` must contain at least 2 dimensions.")
    if n_dims > 5:
        raise ValueError("`out_shape` must contain at most 5 dimensions.")
    if any(size <= 0 for size in out_shape):
        raise ValueError("`out_shape` must contain positive integers.")
    if arr.ndim != 2:
        raise ValueError("`arr` must be a 2D array.")

    plane_shape = (out_shape[-2], out_shape[-1])
    plane = _fit_to_shape_2d(
        arr, out_shape=plane_shape, label_mode=ensure_unique_info
    )
    return np.broadcast_to(plane, out_shape)
ngio/common/_zoom.py CHANGED
@@ -12,12 +12,11 @@ def _stacked_zoom(x, zoom_y, zoom_x, order=1, mode="grid-constant", grid_mode=Tr
12
12
  *rest, yshape, xshape = x.shape
13
13
  x = x.reshape(-1, yshape, xshape)
14
14
  scale_xy = (zoom_y, zoom_x)
15
- x_out = np.stack(
16
- [
17
- scipy_zoom(x[i], scale_xy, order=order, mode=mode, grid_mode=True)
18
- for i in range(x.shape[0])
19
- ]
20
- )
15
+ _x_out = [
16
+ scipy_zoom(x[i], scale_xy, order=order, mode=mode, grid_mode=grid_mode)
17
+ for i in range(x.shape[0])
18
+ ]
19
+ x_out = np.stack(_x_out) # type: ignore (scipy_zoom returns np.ndarray, but type is not inferred correctly)
21
20
  return x_out.reshape(*rest, *x_out.shape[1:])
22
21
 
23
22
 
@@ -45,13 +44,13 @@ def fast_zoom(x, zoom, order=1, mode="grid-constant", grid_mode=True, auto_stack
45
44
  )
46
45
  else:
47
46
  xs = scipy_zoom(xs, new_zoom, order=order, mode=mode, grid_mode=grid_mode)
48
- x = np.expand_dims(xs, axis=singletons)
47
+ x = np.expand_dims(xs, axis=singletons) # type: ignore (scipy_zoom returns np.ndarray, but type is not inferred correctly)
49
48
  return x
50
49
 
51
50
 
52
51
  def _zoom_inputs_check(
53
52
  source_array: np.ndarray | da.Array,
54
- scale: tuple[int, ...] | None = None,
53
+ scale: tuple[int | float, ...] | None = None,
55
54
  target_shape: tuple[int, ...] | None = None,
56
55
  ) -> tuple[np.ndarray, tuple[int, ...]]:
57
56
  if scale is None and target_shape is None:
@@ -74,12 +73,18 @@ def _zoom_inputs_check(
74
73
  _scale = np.array(scale)
75
74
  _target_shape = tuple(np.array(source_array.shape) * scale)
76
75
 
76
+ if len(_scale) != source_array.ndim:
77
+ raise NgioValueError(
78
+ f"Cannot scale array of shape {source_array.shape} with factors {_scale}."
79
+ " Target shape must have the same number of dimensions as the source array."
80
+ )
81
+
77
82
  return _scale, _target_shape
78
83
 
79
84
 
80
85
  def dask_zoom(
81
86
  source_array: da.Array,
82
- scale: tuple[int, ...] | None = None,
87
+ scale: tuple[float | int, ...] | None = None,
83
88
  target_shape: tuple[int, ...] | None = None,
84
89
  order: Literal[0, 1, 2] = 1,
85
90
  ) -> da.Array:
@@ -106,10 +111,10 @@ def dask_zoom(
106
111
  )
107
112
 
108
113
  # Rechunk to better match the scaling operation
109
- source_chunks = np.array(source_array.chunksize)
114
+ source_chunks = np.array(source_array.chunksize) # type: ignore (da.Array.chunksize is a tuple of ints)
110
115
  better_source_chunks = np.maximum(1, np.round(source_chunks * _scale) / _scale)
111
116
  better_source_chunks = better_source_chunks.astype(int)
112
- source_array = source_array.rechunk(better_source_chunks) # type: ignore
117
+ source_array = source_array.rechunk(better_source_chunks) # type: ignore (better_source_chunks is a valid input for rechunk)
113
118
 
114
119
  # Calculate the block output shape
115
120
  block_output_shape = tuple(np.ceil(better_source_chunks * _scale).astype(int))
@@ -130,7 +135,7 @@ def dask_zoom(
130
135
 
131
136
  def numpy_zoom(
132
137
  source_array: np.ndarray,
133
- scale: tuple[int, ...] | None = None,
138
+ scale: tuple[int | float, ...] | None = None,
134
139
  target_shape: tuple[int, ...] | None = None,
135
140
  order: Literal[0, 1, 2] = 1,
136
141
  ) -> np.ndarray:
@@ -0,0 +1,5 @@
1
+ """Concrete IO transformations."""
2
+
3
+ from ngio.common.transforms._zoom import ZoomTransform
4
+
5
+ __all__ = ["ZoomTransform"]
@@ -0,0 +1,12 @@
1
+ """Transforms and pre-post-processors for label images."""
2
+
3
+ # def make_unique_label_np(x: np.ndarray, p: int, n: int) -> np.ndarray:
4
+ # """Make a unique label for the patch."""
5
+ # x = np.where(x > 0, (1 + p - n) + x * n, 0)
6
+ # return x
7
+ #
8
+ #
9
+ # def make_unique_label_da(x: da.Array, p: int, n: int) -> da.Array:
10
+ # """Make a unique label for the patch."""
11
+ # x = da.where(x > 0, (1 + p - n) + x * n, 0)
12
+ # return x
@@ -0,0 +1,109 @@
1
+ from collections.abc import Sequence
2
+ from typing import Literal
3
+
4
+ import dask.array as da
5
+ import numpy as np
6
+
7
+ from ngio.common._array_io_utils import apply_sequence_axes_ops
8
+ from ngio.common._dimensions import Dimensions
9
+ from ngio.common._zoom import dask_zoom, numpy_zoom
10
+ from ngio.ome_zarr_meta.ngio_specs import SlicingOps
11
+
12
+
13
class ZoomTransform:
    """Zoom arrays by a fixed, per-axis scale factor.

    `scale` holds one factor per axis. When the transform is built with
    `from_dimensions`, the factors follow the axes order of the original
    dimensions. Before zooming, the factors go through the same
    squeeze/transpose/expand operations described by the `SlicingOps` of the
    array being processed, so they line up with that array's axes.
    """

    def __init__(self, scale: Sequence[float], order: Literal[0, 1, 2]):
        """Initialize the transform.

        Args:
            scale: Zoom factor for each axis.
            order: Interpolation order forwarded to `numpy_zoom` / `dask_zoom`.
        """
        self._scale = tuple(scale)
        self._order: Literal[0, 1, 2] = order

    @property
    def scale(self) -> tuple[float, ...]:
        """Per-axis zoom factors."""
        return self._scale

    @property
    def inv_scale(self) -> tuple[float, ...]:
        """Per-axis reciprocal of `scale`, used by the inverse transforms."""
        return tuple(1 / s for s in self._scale)

    @classmethod
    def from_dimensions(
        cls,
        original_dimension: Dimensions,
        target_dimension: Dimensions,
        order: Literal[0, 1, 2],
    ) -> "ZoomTransform":
        """Build the transform that maps `original_dimension` onto `target_dimension`.

        For every axis of `original_dimension`, the factor is
        ``target_size / original_size``. Axes missing from `target_dimension`
        get a factor of 1.

        Args:
            original_dimension: Dimensions of the source array.
            target_dimension: Dimensions the source should be zoomed to.
            order: Interpolation order forwarded to the zoom functions.

        Returns:
            ZoomTransform: The transform with one factor per original axis.

        Raises:
            ValueError: If an axis shared by both dimensions has no size.
        """
        scale: list[float] = []
        for ax_name in original_dimension.axes_mapper.axes_names:
            if target_dimension.axes_mapper.get_axis(name=ax_name) is None:
                scale.append(1)
                continue
            t_shape = target_dimension.get(ax_name)
            o_shape = original_dimension.get(ax_name)
            # Explicit check instead of `assert`, which is stripped under -O.
            if t_shape is None or o_shape is None:
                raise ValueError(
                    f"Cannot compute zoom factor for axis {ax_name!r}: "
                    f"original size {o_shape}, target size {t_shape}."
                )
            scale.append(t_shape / o_shape)

        return cls(scale, order)

    def _aligned_scale(
        self, scale: tuple[float, ...], slicing_ops: SlicingOps
    ) -> tuple[float, ...]:
        """Rearrange `scale` to follow the axes layout described by `slicing_ops`.

        This applies the squeeze, transpose and expand operations of
        `slicing_ops` to the factors. `default=1` is presumably the factor
        given to axes inserted by the expand step; confirm against
        `apply_sequence_axes_ops`.
        """
        return tuple(
            apply_sequence_axes_ops(
                scale,
                default=1,
                squeeze_axes=slicing_ops.squeeze_axes,
                transpose_axes=slicing_ops.transpose_axes,
                expand_axes=slicing_ops.expand_axes,
            )
        )

    def apply_numpy_transform(
        self, array: np.ndarray, slicing_ops: SlicingOps
    ) -> np.ndarray:
        """Zoom a NumPy array by `scale`.

        Args:
            array: Array whose axes layout is described by `slicing_ops`.
            slicing_ops: Axes operations used to align the factors to `array`.

        Returns:
            np.ndarray: The zoomed array.
        """
        scale = self._aligned_scale(self.scale, slicing_ops)
        return numpy_zoom(source_array=array, scale=scale, order=self._order)

    def apply_dask_transform(
        self, array: da.Array, slicing_ops: SlicingOps
    ) -> da.Array:
        """Zoom a Dask array by `scale`.

        Args:
            array: Array whose axes layout is described by `slicing_ops`.
            slicing_ops: Axes operations used to align the factors to `array`.

        Returns:
            da.Array: The zoomed array.
        """
        scale = self._aligned_scale(self.scale, slicing_ops)
        return dask_zoom(source_array=array, scale=scale, order=self._order)

    def apply_inverse_numpy_transform(
        self, array: np.ndarray, slicing_ops: SlicingOps
    ) -> np.ndarray:
        """Zoom a NumPy array by `inv_scale`, undoing `apply_numpy_transform`.

        Args:
            array: Array whose axes layout is described by `slicing_ops`.
            slicing_ops: Axes operations used to align the factors to `array`.

        Returns:
            np.ndarray: The zoomed array.
        """
        scale = self._aligned_scale(self.inv_scale, slicing_ops)
        return numpy_zoom(source_array=array, scale=scale, order=self._order)

    def apply_inverse_dask_transform(
        self, array: da.Array, slicing_ops: SlicingOps
    ) -> da.Array:
        """Zoom a Dask array by `inv_scale`, undoing `apply_dask_transform`.

        Args:
            array: Array whose axes layout is described by `slicing_ops`.
            slicing_ops: Axes operations used to align the factors to `array`.

        Returns:
            da.Array: The zoomed array.
        """
        scale = self._aligned_scale(self.inv_scale, slicing_ops)
        return dask_zoom(source_array=array, scale=scale, order=self._order)
@@ -0,0 +1,5 @@
1
+ """This module provides experimental features.
2
+
3
+ Use with caution as these features may change or be removed in future releases
4
+ without notice.
5
+ """
@@ -0,0 +1,17 @@
1
+ """This file is part of NGIO, a library for working with OME-Zarr data."""
2
+
3
+ from ngio.experimental.iterators._feature import FeatureExtractorIterator
4
+ from ngio.experimental.iterators._image_processing import ImageProcessingIterator
5
+ from ngio.experimental.iterators._segmentation import (
6
+ MaskedSegmentationIterator,
7
+ SegmentationIterator,
8
+ )
9
+
10
+ # from ngio.experimental.iterators._builder import IteratorBuilder
11
+
12
+ __all__ = [
13
+ "FeatureExtractorIterator",
14
+ "ImageProcessingIterator",
15
+ "MaskedSegmentationIterator",
16
+ "SegmentationIterator",
17
+ ]
@@ -0,0 +1,170 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Callable
3
+ from typing import Self
4
+
5
+ from ngio import Roi
6
+ from ngio.experimental.iterators._rois_utils import (
7
+ by_chunks,
8
+ by_yx,
9
+ by_zyx,
10
+ grid,
11
+ rois_product,
12
+ )
13
+ from ngio.images._abstract_image import AbstractImage
14
+ from ngio.tables import GenericRoiTable
15
+
16
+
17
class AbstractIteratorBuilder(ABC):
    """Base class for building iterators over ROIs.

    Subclasses hold a list of ROIs and a reference image. They implement the
    getter/setter builders used to read and write the pixels of each ROI.
    The ROI-manipulation methods (`grid`, `by_yx`, `by_zyx`, `by_chunks`,
    `product`) do not mutate the current instance. They return a new instance,
    rebuilt from `get_init_kwargs()`, carrying a different ROI list.
    """

    # ROIs visited by the iter_* / map_* methods, in iteration order.
    _rois: list[Roi]
    # Reference image passed to the ROI helpers (grid, by_yx, by_zyx, by_chunks).
    _ref_image: AbstractImage

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(regions={len(self._rois)})"

    @abstractmethod
    def get_init_kwargs(self) -> dict:
        """Return the keyword arguments needed to rebuild this iterator."""
        raise NotImplementedError

    @property
    def rois(self) -> list[Roi]:
        """Get the list of ROIs for the iterator."""
        return self._rois

    def _set_rois(self, rois: list[Roi]) -> None:
        """Set the list of ROIs for the iterator."""
        self._rois = rois

    @property
    def ref_image(self) -> AbstractImage:
        """Get the reference image for the iterator."""
        return self._ref_image

    def _new_from_rois(self, rois: list[Roi]) -> Self:
        """Create a new instance of the iterator with a different set of ROIs."""
        init_kwargs = self.get_init_kwargs()
        new_instance = self.__class__(**init_kwargs)
        # __init__ computes its own default ROIs; override them afterwards.
        new_instance._set_rois(rois)
        return new_instance

    def grid(
        self,
        size_x: int | None = None,
        size_y: int | None = None,
        size_z: int | None = None,
        size_t: int | None = None,
        stride_x: int | None = None,
        stride_y: int | None = None,
        stride_z: int | None = None,
        stride_t: int | None = None,
        base_name: str = "",
    ) -> Self:
        """Return a new iterator whose ROIs form a grid over the current ROIs.

        The sizes and strides are forwarded unchanged to the `grid` helper,
        together with the current ROIs and the reference image.
        """
        rois = grid(
            rois=self.rois,
            ref_image=self.ref_image,
            size_x=size_x,
            size_y=size_y,
            size_z=size_z,
            size_t=size_t,
            stride_x=stride_x,
            stride_y=stride_y,
            stride_z=stride_z,
            stride_t=stride_t,
            base_name=base_name,
        )
        return self._new_from_rois(rois)

    def by_yx(self) -> Self:
        """Return a new iterator that iterates over ROIs by YX coordinates."""
        rois = by_yx(self.rois, self.ref_image)
        return self._new_from_rois(rois)

    def by_zyx(self, strict: bool = True) -> Self:
        """Return a new iterator that iterates over ROIs by ZYX coordinates.

        Args:
            strict (bool): If True, only iterate over ZYX if a Z axis
                is present and not of size 1.
        """
        rois = by_zyx(self.rois, self.ref_image, strict=strict)
        return self._new_from_rois(rois)

    def by_chunks(self, overlap_xy: int = 0, overlap_z: int = 0) -> Self:
        """Return a new iterator that iterates over ROIs by chunks.

        Args:
            overlap_xy (int): Overlap in XY dimensions.
            overlap_z (int): Overlap in Z dimension.

        Returns:
            Self: A new iterator of the same class with chunked ROIs.
        """
        rois = by_chunks(
            self.rois, self.ref_image, overlap_xy=overlap_xy, overlap_z=overlap_z
        )
        return self._new_from_rois(rois)

    def product(self, other: list[Roi] | GenericRoiTable) -> Self:
        """Cartesian product of the current ROIs with an arbitrary list of ROIs.

        Args:
            other: ROIs to combine with, or a ROI table whose `rois()` are used.
        """
        if isinstance(other, GenericRoiTable):
            other = other.rois()
        rois = rois_product(self.rois, other)
        return self._new_from_rois(rois)

    @abstractmethod
    def build_numpy_getter(self, roi: Roi):
        """Build a zero-argument callable that reads the data of `roi`."""
        raise NotImplementedError

    @abstractmethod
    def build_numpy_setter(self, roi: Roi):
        """Build a callable that writes data back into `roi`."""
        raise NotImplementedError

    @abstractmethod
    def build_dask_getter(self, roi: Roi):
        """Build a zero-argument callable that reads `roi` as a Dask array."""
        raise NotImplementedError

    @abstractmethod
    def build_dask_setter(self, roi: Roi):
        """Build a callable that writes a Dask array back into `roi`."""
        raise NotImplementedError

    @abstractmethod
    def post_consolidate(self) -> None:
        """Post-process the data after all ROIs have been processed."""
        raise NotImplementedError

    def iter_as_numpy(self):
        """Iterate over the ROIs, yielding ``(data, setter)`` pairs.

        `post_consolidate` runs once after the last ROI. Because this is a
        generator, that only happens if the generator is fully consumed.
        """
        for roi in self.rois:
            data = self.build_numpy_getter(roi)()
            yield data, self.build_numpy_setter(roi)
        self.post_consolidate()

    def iter_as_dask(self):
        """Iterate over the ROIs, yielding ``(dask_data, setter)`` pairs.

        Like `iter_as_numpy`, `post_consolidate` runs once after the last ROI,
        provided the generator is fully consumed.
        """
        for roi in self.rois:
            data = self.build_dask_getter(roi)()
            yield data, self.build_dask_setter(roi)
        # Consolidate like iter_as_numpy / map_as_dask do.
        self.post_consolidate()

    def map_as_numpy(self, func: Callable) -> None:
        """Apply `func` to the pixels of every ROI and write the result back.

        For each ROI, the data is read with the NumPy getter and passed to
        `func`. The returned value is written with the NumPy setter.
        `post_consolidate` runs after all ROIs.

        Args:
            func: Callable that takes one ROI's data and returns the data to
                write back for that ROI.
        """
        for roi in self.rois:
            # Call the getter: passing the getter itself to `func` was a bug.
            data = self.build_numpy_getter(roi)()
            result = func(data)
            # Actually write the result; previously the setter was never called.
            self.build_numpy_setter(roi)(result)
        self.post_consolidate()

    def map_as_dask(self, func: Callable) -> None:
        """Apply `func` to the Dask data of every ROI and write the result back.

        For each ROI, the data is read with the Dask getter and passed to
        `func`. The returned value is written with the Dask setter.
        `post_consolidate` runs after all ROIs.

        Args:
            func: Callable that takes one ROI's Dask data and returns the data
                to write back for that ROI.
        """
        for roi in self.rois:
            data = self.build_dask_getter(roi)()
            result = func(data)
            self.build_dask_setter(roi)(result)
        self.post_consolidate()
@@ -0,0 +1,151 @@
1
+ from collections.abc import Callable, Generator, Sequence
2
+
3
+ import dask.array as da
4
+ import numpy as np
5
+
6
+ from ngio.common import (
7
+ Roi,
8
+ TransformProtocol,
9
+ build_roi_dask_getter,
10
+ build_roi_numpy_getter,
11
+ )
12
+ from ngio.experimental.iterators._abstract_iterator import AbstractIteratorBuilder
13
+ from ngio.images import Image, Label
14
+ from ngio.images._image import (
15
+ ChannelSlicingInputType,
16
+ add_channel_selection_to_slicing_dict,
17
+ )
18
+
19
+
20
class FeatureExtractorIterator(AbstractIteratorBuilder):
    """Read-only iterator yielding image data, label data and ROI for each ROI.

    For every ROI, it reads the region from `input_image` (optionally
    restricted to a channel selection) and the same ROI from `input_label`.
    Nothing is written back: the setter builders return None and the
    `map_as_*` methods raise NotImplementedError.
    """

    def __init__(
        self,
        input_image: Image,
        input_label: Label,
        channel_selection: ChannelSlicingInputType = None,
        axes_order: Sequence[str] | None = None,
        input_transforms: Sequence[TransformProtocol] | None = None,
        label_transforms: Sequence[TransformProtocol] | None = None,
    ) -> None:
        """Initialize the iterator from an input image and a label image.

        The initial ROIs are those of `input_image.build_image_roi_table()`.
        Use the ROI helpers (`grid`, `by_yx`, ...) to derive other ROI sets.

        Args:
            input_image (Image): Image to read the intensity data from.
            input_label (Label): Label image read over the same ROIs.
            channel_selection (ChannelSlicingInputType): Optional selection of
                channels applied when reading `input_image`.
            axes_order (Sequence[str] | None): Optional axes order, forwarded to
                the ROI getters of both images.
            input_transforms (Sequence[TransformProtocol] | None): Optional
                transforms applied when reading `input_image`.
            label_transforms (Sequence[TransformProtocol] | None): Optional
                transforms applied when reading `input_label`.
        """
        self._input = input_image
        self._input_label = input_label
        self._ref_image = input_image
        self._rois = input_image.build_image_roi_table().rois()

        # Set iteration parameters
        self._input_slicing_kwargs = add_channel_selection_to_slicing_dict(
            image=self._input, channel_selection=channel_selection, slicing_dict={}
        )
        self._channel_selection = channel_selection
        self._axes_order = axes_order
        self._input_transforms = input_transforms
        self._label_transforms = label_transforms

    def get_init_kwargs(self) -> dict:
        """Return the initialization arguments for the iterator."""
        return {
            "input_image": self._input,
            "input_label": self._input_label,
            "channel_selection": self._channel_selection,
            "axes_order": self._axes_order,
            "input_transforms": self._input_transforms,
            "label_transforms": self._label_transforms,
        }

    def build_numpy_getter(self, roi: Roi):
        """Build a callable returning ``(image_data, label_data, roi)`` as NumPy."""
        data_getter = build_roi_numpy_getter(
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            roi=roi,
            slicing_dict=self._input_slicing_kwargs,
        )
        label_getter = build_roi_numpy_getter(
            zarr_array=self._input_label.zarr_array,
            dimensions=self._input_label.dimensions,
            axes_order=self._axes_order,
            transforms=self._label_transforms,
            pixel_size=self._input_label.pixel_size,
            roi=roi,
            remove_channel_selection=True,
        )
        return lambda: (data_getter(), label_getter(), roi)

    def build_numpy_setter(self, roi: Roi):
        """Return None: this iterator never writes data."""
        return None

    def build_dask_getter(self, roi: Roi):
        """Build a callable returning ``(image_data, label_data, roi)`` as Dask."""
        data_getter = build_roi_dask_getter(
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            roi=roi,
            slicing_dict=self._input_slicing_kwargs,
        )
        label_getter = build_roi_dask_getter(
            zarr_array=self._input_label.zarr_array,
            dimensions=self._input_label.dimensions,
            axes_order=self._axes_order,
            transforms=self._label_transforms,
            pixel_size=self._input_label.pixel_size,
            roi=roi,
            remove_channel_selection=True,
        )
        return lambda: (data_getter(), label_getter(), roi)

    def build_dask_setter(self, roi: Roi):
        """Return None: this iterator never writes data."""
        return None

    def post_consolidate(self):
        """Do nothing: nothing is written, so there is nothing to consolidate."""
        pass

    def iter_as_numpy(self) -> Generator[tuple[np.ndarray, np.ndarray, Roi], None, None]:  # type: ignore (non compatible override)
        """Iterate over the ROIs, yielding NumPy data.

        Yields:
            tuple[np.ndarray, np.ndarray, Roi]: For each ROI, the input image
            data, the label data and the ROI itself.
        """
        for (data, label, roi), _ in super().iter_as_numpy():
            yield data, label, roi

    def map_as_numpy(self, func: Callable[[np.ndarray], np.ndarray]) -> None:
        """Not supported: this iterator is read-only."""
        raise NotImplementedError("Numpy mapping not implemented for this iterator.")

    def iter_as_dask(self) -> Generator[tuple[da.Array, da.Array, Roi], None, None]:  # type: ignore (non compatible override)
        """Iterate over the ROIs, yielding Dask data.

        Yields:
            tuple[da.Array, da.Array, Roi]: For each ROI, the input image data,
            the label data and the ROI itself.
        """
        for (data, label, roi), _ in super().iter_as_dask():
            yield data, label, roi

    def map_as_dask(self, func: Callable[[da.Array], da.Array]) -> None:
        """Not supported: this iterator is read-only."""
        raise NotImplementedError("Dask mapping not implemented for this iterator.")