ngio 0.3.5__py3-none-any.whl → 0.4.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ngio/__init__.py +6 -0
  2. ngio/common/__init__.py +50 -48
  3. ngio/common/_array_io_pipes.py +554 -0
  4. ngio/common/_array_io_utils.py +508 -0
  5. ngio/common/_dimensions.py +63 -27
  6. ngio/common/_masking_roi.py +38 -10
  7. ngio/common/_pyramid.py +9 -7
  8. ngio/common/_roi.py +583 -72
  9. ngio/common/_synt_images_utils.py +101 -0
  10. ngio/common/_zoom.py +17 -12
  11. ngio/common/transforms/__init__.py +5 -0
  12. ngio/common/transforms/_label.py +12 -0
  13. ngio/common/transforms/_zoom.py +109 -0
  14. ngio/experimental/__init__.py +5 -0
  15. ngio/experimental/iterators/__init__.py +17 -0
  16. ngio/experimental/iterators/_abstract_iterator.py +170 -0
  17. ngio/experimental/iterators/_feature.py +151 -0
  18. ngio/experimental/iterators/_image_processing.py +169 -0
  19. ngio/experimental/iterators/_rois_utils.py +127 -0
  20. ngio/experimental/iterators/_segmentation.py +282 -0
  21. ngio/hcs/_plate.py +41 -36
  22. ngio/images/__init__.py +22 -1
  23. ngio/images/_abstract_image.py +247 -117
  24. ngio/images/_create.py +15 -15
  25. ngio/images/_create_synt_container.py +128 -0
  26. ngio/images/_image.py +425 -62
  27. ngio/images/_label.py +33 -30
  28. ngio/images/_masked_image.py +396 -122
  29. ngio/images/_ome_zarr_container.py +203 -66
  30. ngio/{common → images}/_table_ops.py +41 -41
  31. ngio/ome_zarr_meta/ngio_specs/__init__.py +2 -8
  32. ngio/ome_zarr_meta/ngio_specs/_axes.py +151 -128
  33. ngio/ome_zarr_meta/ngio_specs/_channels.py +55 -18
  34. ngio/ome_zarr_meta/ngio_specs/_dataset.py +7 -7
  35. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +3 -3
  36. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +11 -68
  37. ngio/ome_zarr_meta/v04/_v04_spec_utils.py +1 -1
  38. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  39. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  40. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  41. ngio/resources/__init__.py +54 -0
  42. ngio/resources/resource_model.py +35 -0
  43. ngio/tables/backends/_abstract_backend.py +5 -6
  44. ngio/tables/backends/_anndata.py +1 -1
  45. ngio/tables/backends/_anndata_utils.py +3 -3
  46. ngio/tables/backends/_non_zarr_backends.py +1 -1
  47. ngio/tables/backends/_table_backends.py +0 -1
  48. ngio/tables/backends/_utils.py +3 -3
  49. ngio/tables/v1/_roi_table.py +156 -69
  50. ngio/utils/__init__.py +2 -3
  51. ngio/utils/_logger.py +19 -0
  52. ngio/utils/_zarr_utils.py +1 -5
  53. {ngio-0.3.5.dist-info → ngio-0.4.0a2.dist-info}/METADATA +3 -1
  54. ngio-0.4.0a2.dist-info/RECORD +76 -0
  55. ngio/common/_array_pipe.py +0 -288
  56. ngio/common/_axes_transforms.py +0 -64
  57. ngio/common/_common_types.py +0 -5
  58. ngio/common/_slicer.py +0 -96
  59. ngio-0.3.5.dist-info/RECORD +0 -61
  60. {ngio-0.3.5.dist-info → ngio-0.4.0a2.dist-info}/WHEEL +0 -0
  61. {ngio-0.3.5.dist-info → ngio-0.4.0a2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,169 @@
1
+ from collections.abc import Callable, Generator, Sequence
2
+
3
+ import dask.array as da
4
+ import numpy as np
5
+
6
+ from ngio.common import (
7
+ Roi,
8
+ TransformProtocol,
9
+ build_roi_dask_getter,
10
+ build_roi_dask_setter,
11
+ build_roi_numpy_getter,
12
+ build_roi_numpy_setter,
13
+ )
14
+ from ngio.experimental.iterators._abstract_iterator import AbstractIteratorBuilder
15
+ from ngio.images import Image
16
+ from ngio.images._image import (
17
+ ChannelSlicingInputType,
18
+ add_channel_selection_to_slicing_dict,
19
+ )
20
+ from ngio.utils._errors import NgioValidationError
21
+
22
+
23
class ImageProcessingIterator(AbstractIteratorBuilder):
    """Iterate over the ROIs of an input image and write results to an output image.

    The ROIs are taken from ``input_image.build_image_roi_table()``.
    """

    def __init__(
        self,
        input_image: Image,
        output_image: Image,
        input_channel_selection: ChannelSlicingInputType = None,
        output_channel_selection: ChannelSlicingInputType = None,
        axes_order: Sequence[str] | None = None,
        input_transforms: Sequence[TransformProtocol] | None = None,
        output_transforms: Sequence[TransformProtocol] | None = None,
    ) -> None:
        """Initialize the iterator from an input and an output image.

        Args:
            input_image (Image): The image the ROI pixels are read from. Its
                image ROI table defines the ROIs to iterate over.
            output_image (Image): The image the processed ROIs are written to.
            input_channel_selection (ChannelSlicingInputType): Optional
                selection of channels to read from the input image.
            output_channel_selection (ChannelSlicingInputType): Optional
                selection of channels to write to in the output image.
            axes_order (Sequence[str] | None): Optional axes order, forwarded
                to every ROI getter and setter.
            input_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the input image.
            output_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the output image.

        Raises:
            NgioValidationError: If the input and output images have
                incompatible dimensions.
        """
        self._input = input_image
        self._output = output_image
        self._ref_image = input_image
        self._rois = input_image.build_image_roi_table().rois()

        # Set iteration parameters: the channel selections are converted once
        # into slicing dicts reused by every getter (input) / setter (output).
        self._input_slicing_kwargs = add_channel_selection_to_slicing_dict(
            image=self._input,
            channel_selection=input_channel_selection,
            slicing_dict={},
        )
        self._output_slicing_kwargs = add_channel_selection_to_slicing_dict(
            image=self._output,
            channel_selection=output_channel_selection,
            slicing_dict={},
        )
        # Raw constructor arguments, returned as-is by ``get_init_kwargs``.
        self._input_channel_selection = input_channel_selection
        self._output_channel_selection = output_channel_selection
        self._axes_order = axes_order
        self._input_transforms = input_transforms
        self._output_transforms = output_transforms

        # Check compatibility between input and output images
        if not self._input.dimensions.is_compatible_with(self._output.dimensions):
            raise NgioValidationError(
                "Input image and output image have incompatible dimensions. "
                f"Input: {self._input.dimensions}, Output: {self._output.dimensions}."
            )

    def get_init_kwargs(self) -> dict:
        """Return the keyword arguments this iterator was initialized with."""
        return {
            "input_image": self._input,
            "output_image": self._output,
            "input_channel_selection": self._input_channel_selection,
            "output_channel_selection": self._output_channel_selection,
            "axes_order": self._axes_order,
            "input_transforms": self._input_transforms,
            "output_transforms": self._output_transforms,
        }

    def build_numpy_getter(self, roi: Roi):
        """Build a NumPy reader for ``roi`` on the input image."""
        return build_roi_numpy_getter(
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            roi=roi,
            slicing_dict=self._input_slicing_kwargs,
        )

    def build_numpy_setter(self, roi: Roi):
        """Build a NumPy writer for ``roi`` on the output image."""
        return build_roi_numpy_setter(
            zarr_array=self._output.zarr_array,
            dimensions=self._output.dimensions,
            axes_order=self._axes_order,
            transforms=self._output_transforms,
            pixel_size=self._output.pixel_size,
            roi=roi,
            slicing_dict=self._output_slicing_kwargs,
        )

    def build_dask_getter(self, roi: Roi):
        """Build a Dask reader for ``roi`` on the input image."""
        return build_roi_dask_getter(
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            roi=roi,
            slicing_dict=self._input_slicing_kwargs,
        )

    def build_dask_setter(self, roi: Roi):
        """Build a Dask writer for ``roi`` on the output image."""
        return build_roi_dask_setter(
            zarr_array=self._output.zarr_array,
            dimensions=self._output.dimensions,
            axes_order=self._axes_order,
            transforms=self._output_transforms,
            pixel_size=self._output.pixel_size,
            roi=roi,
            slicing_dict=self._output_slicing_kwargs,
        )

    def post_consolidate(self):
        """Consolidate the output image."""
        self._output.consolidate()

    def iter_as_numpy(
        self,
    ) -> Generator[tuple[np.ndarray, Callable[[np.ndarray], None]]]:
        """Create an iterator over the pixels of the ROIs as NumPy arrays.

        Returns:
            Generator[tuple[np.ndarray, Callable[[np.ndarray], None]]]: For
                each ROI, the input image data as a NumPy array and a writer
                that stores a NumPy array into the output image.
        """
        return super().iter_as_numpy()

    def map_as_numpy(self, func: Callable[[np.ndarray], np.ndarray]) -> None:
        """Apply ``func`` to the NumPy pixels of every ROI (base-class logic)."""
        return super().map_as_numpy(func)

    def iter_as_dask(self) -> Generator[tuple[da.Array, Callable[[da.Array], None]]]:
        """Create an iterator over the pixels of the ROIs as Dask arrays.

        Returns:
            Generator[tuple[da.Array, Callable[[da.Array], None]]]: For each
                ROI, the input image data as a Dask array and a writer that
                stores a Dask array into the output image.
        """
        return super().iter_as_dask()

    def map_as_dask(self, func: Callable[[da.Array], da.Array]) -> None:
        """Apply ``func`` to the Dask pixels of every ROI (base-class logic)."""
        return super().map_as_dask(func)
@@ -0,0 +1,127 @@
1
+ from ngio import Roi, RoiPixels
2
+ from ngio.images._abstract_image import AbstractImage
3
+
4
+
5
def rois_product(rois_a: list[Roi], rois_b: list[Roi]) -> list[Roi]:
    """Return the pairwise intersections of two lists of ROIs.

    Every ROI of ``rois_a`` (outer loop) is intersected with every ROI of
    ``rois_b`` (inner loop). Falsy results are dropped; presumably
    ``Roi.intersection`` returns ``None`` for non-overlapping ROIs.
    """
    candidates = (
        first.intersection(second) for first in rois_a for second in rois_b
    )
    return [overlap for overlap in candidates if overlap]
14
+
15
+
16
def grid(
    rois: list[Roi],
    ref_image: AbstractImage,
    size_x: int | None = None,
    size_y: int | None = None,
    size_z: int | None = None,
    size_t: int | None = None,
    stride_x: int | None = None,
    stride_y: int | None = None,
    stride_z: int | None = None,
    stride_t: int | None = None,
    base_name: str = "",
) -> list[Roi]:
    """Split ``rois`` along a regular pixel grid laid out on ``ref_image``.

    Tiles are generated in pixel coordinates over the full extent of
    ``ref_image``. Each tile is converted to a ``Roi`` using
    ``ref_image.pixel_size`` and intersected with ``rois`` via
    ``rois_product``.

    Args:
        rois: ROIs to split.
        ref_image: Image defining the grid extent and the pixel size.
        size_x, size_y, size_z, size_t: Tile size in pixels along each axis.
            Defaults to the image extent along that axis; axes absent from
            the image fall back to 1 through ``dimensions.get(..., default=1)``.
        stride_x, stride_y, stride_z, stride_t: Distance in pixels between
            consecutive tile origins along each axis. Defaults to the tile
            size, which gives non-overlapping tiles.
        base_name: Prefix of the tile names, formatted as
            ``"{base_name}({t}, {z}, {y}, {x})"``.

    Returns:
        The intersections between ``rois`` and the grid tiles, with empty
        intersections dropped.

    Raises:
        ValueError: If any resolved tile size or stride is not positive.
    """
    t_dim = ref_image.dimensions.get("t", default=1)
    z_dim = ref_image.dimensions.get("z", default=1)
    y_dim = ref_image.dimensions.get("y", default=1)
    x_dim = ref_image.dimensions.get("x", default=1)

    size_t = size_t if size_t is not None else t_dim
    size_z = size_z if size_z is not None else z_dim
    size_y = size_y if size_y is not None else y_dim
    size_x = size_x if size_x is not None else x_dim

    stride_t = stride_t if stride_t is not None else size_t
    stride_z = stride_z if stride_z is not None else size_z
    stride_y = stride_y if stride_y is not None else size_y
    stride_x = stride_x if stride_x is not None else size_x

    # Validate before building the grid. A zero step makes ``range`` fail with
    # a cryptic error, and a negative step silently yields no tiles at all.
    for axis, size, stride in (
        ("t", size_t, stride_t),
        ("z", size_z, stride_z),
        ("y", size_y, stride_y),
        ("x", size_x, stride_x),
    ):
        if size <= 0:
            raise ValueError(
                f"Grid size along '{axis}' must be a positive integer, got {size}."
            )
        if stride <= 0:
            raise ValueError(
                f"Grid stride along '{axis}' must be a positive integer, "
                f"got {stride}."
            )

    # Lay out the tiles in pixel space over the whole image, then convert
    # each one to a ``Roi`` so it can be intersected with ``rois``.
    pixel_size = ref_image.pixel_size
    new_rois = []
    for t in range(0, t_dim, stride_t):
        for z in range(0, z_dim, stride_z):
            for y in range(0, y_dim, stride_y):
                for x in range(0, x_dim, stride_x):
                    roi = RoiPixels(
                        name=f"{base_name}({t}, {z}, {y}, {x})",
                        x=x,
                        y=y,
                        z=z,
                        t=t,
                        x_length=size_x,
                        y_length=size_y,
                        z_length=size_z,
                        t_length=size_t,
                    )
                    new_rois.append(roi.to_roi(pixel_size=pixel_size))

    return rois_product(rois, new_rois)
65
+
66
+
67
def by_yx(rois: list[Roi], ref_image: AbstractImage) -> list[Roi]:
    """Split ``rois`` into single ``t``/``z`` planes.

    Delegates to ``grid`` with a unit tile size and unit stride along ``t``
    and ``z``. ``y`` and ``x`` keep ``grid``'s defaults, so each tile spans
    the full YX extent of ``ref_image`` before being intersected with
    ``rois``.
    """
    single_plane = 1
    return grid(
        rois=rois,
        ref_image=ref_image,
        size_t=single_plane,
        stride_t=single_plane,
        size_z=single_plane,
        stride_z=single_plane,
    )
77
+
78
+
79
def by_zyx(rois: list[Roi], ref_image: AbstractImage, strict: bool = True) -> list[Roi]:
    """Split ``rois`` into one ZYX volume per timepoint.

    Delegates to ``grid`` with a unit tile size and unit stride along ``t``.
    ``z``, ``y`` and ``x`` keep ``grid``'s defaults, so each tile spans the
    full ZYX extent of ``ref_image`` before being intersected with ``rois``.

    Args:
        rois: ROIs to split.
        ref_image: Image defining the grid extent and the pixel size.
        strict: If True, require ``ref_image`` to be 3D.

    Returns:
        The per-timepoint intersections with ``rois``.

    Raises:
        ValueError: If ``strict`` is True and ``ref_image.is_3d`` is False.
    """
    if strict and not ref_image.is_3d:
        raise ValueError(
            "Reference Input image must be 3D to iterate by ZYX coordinates. "
            f"Current dimensions: {ref_image.dimensions}"
        )
    return grid(
        rois=rois,
        ref_image=ref_image,
        size_t=1,
        stride_t=1,
    )
92
+
93
+
94
def by_chunks(
    rois: list[Roi],
    ref_image: AbstractImage,
    overlap_xy: int = 0,
    overlap_z: int = 0,
    overlap_t: int = 0,
) -> list[Roi]:
    """Split ``rois`` into tiles sized like the chunks of ``ref_image``.

    Along every axis present in ``ref_image``, the tile size is the matching
    entry of ``ref_image.chunks``. Consecutive tiles start
    ``chunk - overlap`` pixels apart, so neighbouring tiles share ``overlap``
    pixels. Axes absent from the image use ``grid``'s defaults, and their
    overlap is ignored.

    Args:
        rois: ROIs to split.
        ref_image: Image providing the chunk sizes, grid extent and pixel size.
        overlap_xy: Overlap in pixels along both ``y`` and ``x``.
        overlap_z: Overlap in pixels along ``z``.
        overlap_t: Overlap in pixels along ``t``.

    Returns:
        The intersections between ``rois`` and the chunk-sized tiles.

    Raises:
        ValueError: If an overlap is negative or not smaller than the chunk
            size along an axis present in ``ref_image``.
    """
    chunks = ref_image.chunks
    sizes: dict[str, int | None] = {}
    strides: dict[str, int | None] = {}
    for axis, overlap in (
        ("t", overlap_t),
        ("z", overlap_z),
        ("y", overlap_xy),
        ("x", overlap_xy),
    ):
        index = ref_image.axes_mapper.get_index(axis)
        if index is None:
            # Axis not in the image: let ``grid`` pick its defaults.
            sizes[axis] = None
            strides[axis] = None
            continue
        size = chunks[index]
        # overlap >= size would give a zero/negative stride (crash or an
        # empty result); a negative overlap would leave gaps between tiles.
        if not 0 <= overlap < size:
            raise ValueError(
                f"Overlap along '{axis}' must satisfy 0 <= overlap < chunk size "
                f"({size}), got {overlap}."
            )
        sizes[axis] = size
        strides[axis] = size - overlap

    return grid(
        rois=rois,
        ref_image=ref_image,
        size_x=sizes["x"],
        size_y=sizes["y"],
        size_z=sizes["z"],
        size_t=sizes["t"],
        stride_x=strides["x"],
        stride_y=strides["y"],
        stride_z=strides["z"],
        stride_t=strides["t"],
    )
@@ -0,0 +1,282 @@
1
+ from collections.abc import Callable, Generator, Sequence
2
+
3
+ import dask.array as da
4
+ import numpy as np
5
+
6
+ from ngio.common import (
7
+ Roi,
8
+ TransformProtocol,
9
+ build_roi_dask_getter,
10
+ build_roi_dask_setter,
11
+ build_roi_masked_dask_getter,
12
+ build_roi_masked_dask_setter,
13
+ build_roi_masked_numpy_getter,
14
+ build_roi_masked_numpy_setter,
15
+ build_roi_numpy_getter,
16
+ build_roi_numpy_setter,
17
+ )
18
+ from ngio.experimental.iterators._abstract_iterator import AbstractIteratorBuilder
19
+ from ngio.images import Image, Label
20
+ from ngio.images._image import (
21
+ ChannelSlicingInputType,
22
+ add_channel_selection_to_slicing_dict,
23
+ )
24
+ from ngio.images._masked_image import MaskedImage
25
+ from ngio.utils._errors import NgioValidationError
26
+
27
+
28
class SegmentationIterator(AbstractIteratorBuilder):
    """Iterate over the ROIs of an input image and write results to a label image.

    The ROIs are taken from ``input_image.build_image_roi_table()``. The
    setters for the output label are built with
    ``remove_channel_selection=True``.
    """

    def __init__(
        self,
        input_image: Image,
        output_label: Label,
        channel_selection: ChannelSlicingInputType = None,
        axes_order: Sequence[str] | None = None,
        input_transforms: Sequence[TransformProtocol] | None = None,
        output_transforms: Sequence[TransformProtocol] | None = None,
    ) -> None:
        """Initialize the iterator from an input image and an output label.

        Args:
            input_image (Image): The image the ROI pixels are read from. Its
                image ROI table defines the ROIs to iterate over.
            output_label (Label): The label image the ROIs are written to.
            channel_selection (ChannelSlicingInputType): Optional
                selection of channels to read from the input image.
            axes_order (Sequence[str] | None): Optional axes order, forwarded
                to every ROI getter and setter.
            input_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the input image.
            output_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the output label.

        Raises:
            NgioValidationError: If the input image and the output label have
                incompatible dimensions.
        """
        self._input = input_image
        self._output = output_label
        self._ref_image = input_image
        self._rois = input_image.build_image_roi_table().rois()

        # Set iteration parameters: the channel selection is converted once
        # into a slicing dict reused by every input getter.
        self._input_slicing_kwargs = add_channel_selection_to_slicing_dict(
            image=self._input, channel_selection=channel_selection, slicing_dict={}
        )
        # Raw constructor arguments, returned as-is by ``get_init_kwargs``.
        self._channel_selection = channel_selection
        self._axes_order = axes_order
        self._input_transforms = input_transforms
        self._output_transforms = output_transforms

        # Check compatibility between input and output images
        if not self._input.dimensions.is_compatible_with(self._output.dimensions):
            raise NgioValidationError(
                "Input image and output label have incompatible dimensions. "
                f"Input: {self._input.dimensions}, Output: {self._output.dimensions}."
            )

    def get_init_kwargs(self) -> dict:
        """Return the keyword arguments this iterator was initialized with."""
        return {
            "input_image": self._input,
            "output_label": self._output,
            "channel_selection": self._channel_selection,
            "axes_order": self._axes_order,
            "input_transforms": self._input_transforms,
            "output_transforms": self._output_transforms,
        }

    def build_numpy_getter(self, roi: Roi):
        """Build a NumPy reader for ``roi`` on the input image."""
        return build_roi_numpy_getter(
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            roi=roi,
            slicing_dict=self._input_slicing_kwargs,
        )

    def build_numpy_setter(self, roi: Roi):
        """Build a NumPy writer for ``roi`` on the output label."""
        return build_roi_numpy_setter(
            zarr_array=self._output.zarr_array,
            dimensions=self._output.dimensions,
            axes_order=self._axes_order,
            transforms=self._output_transforms,
            pixel_size=self._output.pixel_size,
            roi=roi,
            remove_channel_selection=True,
        )

    def build_dask_getter(self, roi: Roi):
        """Build a Dask reader for ``roi`` on the input image."""
        return build_roi_dask_getter(
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            roi=roi,
            slicing_dict=self._input_slicing_kwargs,
        )

    def build_dask_setter(self, roi: Roi):
        """Build a Dask writer for ``roi`` on the output label."""
        return build_roi_dask_setter(
            zarr_array=self._output.zarr_array,
            dimensions=self._output.dimensions,
            axes_order=self._axes_order,
            transforms=self._output_transforms,
            pixel_size=self._output.pixel_size,
            roi=roi,
            remove_channel_selection=True,
        )

    def post_consolidate(self):
        """Consolidate the output label."""
        self._output.consolidate()

    def iter_as_numpy(
        self,
    ) -> Generator[tuple[np.ndarray, Callable[[np.ndarray], None]]]:
        """Create an iterator over the pixels of the ROIs as NumPy arrays.

        Returns:
            Generator[tuple[np.ndarray, Callable[[np.ndarray], None]]]: For
                each ROI, the input image data as a NumPy array and a writer
                that stores a NumPy array into the label image.
        """
        return super().iter_as_numpy()

    def map_as_numpy(self, func: Callable[[np.ndarray], np.ndarray]) -> None:
        """Apply ``func`` to the NumPy pixels of every ROI (base-class logic)."""
        return super().map_as_numpy(func)

    def iter_as_dask(self) -> Generator[tuple[da.Array, Callable[[da.Array], None]]]:
        """Create an iterator over the pixels of the ROIs as Dask arrays.

        Returns:
            Generator[tuple[da.Array, Callable[[da.Array], None]]]: For each
                ROI, the input image data as a Dask array and a writer that
                stores a Dask array into the label image.
        """
        return super().iter_as_dask()

    def map_as_dask(self, func: Callable[[da.Array], da.Array]) -> None:
        """Apply ``func`` to the Dask pixels of every ROI (base-class logic)."""
        return super().map_as_dask(func)
163
+
164
+
165
class MaskedSegmentationIterator(SegmentationIterator):
    """Segmentation iterator driven by the masking ROIs of a ``MaskedImage``.

    The ROIs come from the masking ROI table of ``input_image``. Every getter
    and setter is built with the arrays, dimensions and pixel size of the
    input's masking label (``input_image._label``). This presumably restricts
    reads and writes to the masked pixels.

    ``get_init_kwargs``, ``post_consolidate`` and the ``iter_*``/``map_*``
    methods are inherited unchanged from ``SegmentationIterator``.
    """

    def __init__(
        self,
        input_image: MaskedImage,
        output_label: Label,
        channel_selection: ChannelSlicingInputType = None,
        axes_order: Sequence[str] | None = None,
        input_transforms: Sequence[TransformProtocol] | None = None,
        output_transforms: Sequence[TransformProtocol] | None = None,
    ) -> None:
        """Initialize the iterator from a masked input image and an output label.

        Args:
            input_image (MaskedImage): The image the ROI pixels are read from.
                Its masking ROI table defines the ROIs to iterate over.
            output_label (Label): The label image the ROIs are written to.
            channel_selection (ChannelSlicingInputType): Optional
                selection of channels to read from the input image.
            axes_order (Sequence[str] | None): Optional axes order, forwarded
                to every ROI getter and setter.
            input_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the input image.
            output_transforms (Sequence[TransformProtocol] | None): Optional
                transforms to apply to the output label.

        Raises:
            NgioValidationError: If the input image and the output label have
                incompatible dimensions.
        """
        self._input = input_image
        self._output = output_label

        self._ref_image = input_image
        # Unlike the parent, iterate over the masking ROIs, not the image ROI.
        self._set_rois(input_image._masking_roi_table.rois())

        # Set iteration parameters
        self._input_slicing_kwargs = add_channel_selection_to_slicing_dict(
            image=self._input, channel_selection=channel_selection, slicing_dict={}
        )
        self._channel_selection = channel_selection
        self._axes_order = axes_order
        self._input_transforms = input_transforms
        self._output_transforms = output_transforms

        # Check compatibility between input and output images
        if not self._input.dimensions.is_compatible_with(self._output.dimensions):
            raise NgioValidationError(
                "Input image and output label have incompatible dimensions. "
                f"Input: {self._input.dimensions}, Output: {self._output.dimensions}."
            )

    def build_numpy_getter(self, roi: Roi):
        """Build a masked NumPy reader for ``roi`` on the input image."""
        return build_roi_masked_numpy_getter(
            roi=roi,
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            label_zarr_array=self._input._label.zarr_array,
            label_dimensions=self._input._label.dimensions,
            label_pixel_size=self._input._label.pixel_size,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            slicing_dict=self._input_slicing_kwargs,
        )

    def build_numpy_setter(self, roi: Roi):
        """Build a masked NumPy writer for ``roi`` on the output label."""
        return build_roi_masked_numpy_setter(
            roi=roi,
            zarr_array=self._output.zarr_array,
            dimensions=self._output.dimensions,
            label_zarr_array=self._input._label.zarr_array,
            label_dimensions=self._input._label.dimensions,
            label_pixel_size=self._input._label.pixel_size,
            axes_order=self._axes_order,
            transforms=self._output_transforms,
            pixel_size=self._output.pixel_size,
            remove_channel_selection=True,
        )

    def build_dask_getter(self, roi: Roi):
        """Build a masked Dask reader for ``roi`` on the input image."""
        return build_roi_masked_dask_getter(
            roi=roi,
            zarr_array=self._input.zarr_array,
            dimensions=self._input.dimensions,
            label_zarr_array=self._input._label.zarr_array,
            label_dimensions=self._input._label.dimensions,
            label_pixel_size=self._input._label.pixel_size,
            axes_order=self._axes_order,
            transforms=self._input_transforms,
            pixel_size=self._input.pixel_size,
            slicing_dict=self._input_slicing_kwargs,
        )

    def build_dask_setter(self, roi: Roi):
        """Build a masked Dask writer for ``roi`` on the output label."""
        return build_roi_masked_dask_setter(
            roi=roi,
            zarr_array=self._output.zarr_array,
            dimensions=self._output.dimensions,
            label_zarr_array=self._input._label.zarr_array,
            label_dimensions=self._input._label.dimensions,
            label_pixel_size=self._input._label.pixel_size,
            axes_order=self._axes_order,
            transforms=self._output_transforms,
            pixel_size=self._output.pixel_size,
            remove_channel_selection=True,
        )