ngio 0.4.0a3__py3-none-any.whl → 0.4.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. ngio/__init__.py +1 -2
  2. ngio/common/__init__.py +2 -51
  3. ngio/common/_dimensions.py +253 -74
  4. ngio/common/_pyramid.py +42 -23
  5. ngio/common/_roi.py +49 -413
  6. ngio/common/_zoom.py +32 -7
  7. ngio/experimental/iterators/__init__.py +0 -2
  8. ngio/experimental/iterators/_abstract_iterator.py +246 -26
  9. ngio/experimental/iterators/_feature.py +90 -52
  10. ngio/experimental/iterators/_image_processing.py +24 -63
  11. ngio/experimental/iterators/_mappers.py +48 -0
  12. ngio/experimental/iterators/_rois_utils.py +4 -4
  13. ngio/experimental/iterators/_segmentation.py +38 -85
  14. ngio/images/_abstract_image.py +192 -95
  15. ngio/images/_create.py +16 -0
  16. ngio/images/_create_synt_container.py +10 -0
  17. ngio/images/_image.py +35 -9
  18. ngio/images/_label.py +26 -3
  19. ngio/images/_masked_image.py +45 -61
  20. ngio/images/_ome_zarr_container.py +33 -0
  21. ngio/io_pipes/__init__.py +75 -0
  22. ngio/io_pipes/_io_pipes.py +361 -0
  23. ngio/io_pipes/_io_pipes_masked.py +488 -0
  24. ngio/io_pipes/_io_pipes_roi.py +152 -0
  25. ngio/io_pipes/_io_pipes_types.py +56 -0
  26. ngio/io_pipes/_match_shape.py +376 -0
  27. ngio/io_pipes/_ops_axes.py +344 -0
  28. ngio/io_pipes/_ops_slices.py +446 -0
  29. ngio/io_pipes/_ops_slices_utils.py +196 -0
  30. ngio/io_pipes/_ops_transforms.py +104 -0
  31. ngio/io_pipes/_zoom_transform.py +175 -0
  32. ngio/ome_zarr_meta/__init__.py +4 -2
  33. ngio/ome_zarr_meta/ngio_specs/__init__.py +4 -4
  34. ngio/ome_zarr_meta/ngio_specs/_axes.py +129 -141
  35. ngio/ome_zarr_meta/ngio_specs/_dataset.py +47 -121
  36. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +30 -22
  37. ngio/ome_zarr_meta/ngio_specs/_pixel_size.py +17 -1
  38. ngio/ome_zarr_meta/v04/_v04_spec_utils.py +33 -30
  39. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  40. ngio/resources/__init__.py +1 -0
  41. ngio/resources/resource_model.py +1 -0
  42. ngio/{common/transforms → transforms}/__init__.py +1 -1
  43. ngio/transforms/_zoom.py +19 -0
  44. ngio/utils/_datasets.py +5 -0
  45. ngio/utils/_zarr_utils.py +5 -1
  46. {ngio-0.4.0a3.dist-info → ngio-0.4.0b1.dist-info}/METADATA +1 -1
  47. ngio-0.4.0b1.dist-info/RECORD +85 -0
  48. ngio/common/_array_io_pipes.py +0 -554
  49. ngio/common/_array_io_utils.py +0 -508
  50. ngio/common/transforms/_label.py +0 -12
  51. ngio/common/transforms/_zoom.py +0 -109
  52. ngio-0.4.0a3.dist-info/RECORD +0 -76
  53. {ngio-0.4.0a3.dist-info → ngio-0.4.0b1.dist-info}/WHEEL +0 -0
  54. {ngio-0.4.0a3.dist-info → ngio-0.4.0b1.dist-info}/licenses/LICENSE +0 -0
ngio/__init__.py CHANGED
@@ -9,7 +9,7 @@ except PackageNotFoundError: # pragma: no cover
9
9
  __author__ = "Lorenzo Cerrone"
10
10
  __email__ = "lorenzo.cerrone@uzh.ch"
11
11
 
12
- from ngio.common import ArrayLike, Dimensions, Roi, RoiPixels
12
+ from ngio.common import Dimensions, Roi, RoiPixels
13
13
  from ngio.hcs import (
14
14
  OmeZarrPlate,
15
15
  OmeZarrWell,
@@ -39,7 +39,6 @@ from ngio.ome_zarr_meta.ngio_specs import (
39
39
  )
40
40
 
41
41
  __all__ = [
42
- "ArrayLike",
43
42
  "AxesSetup",
44
43
  "ChannelSelectionModel",
45
44
  "DefaultNgffVersion",
ngio/common/__init__.py CHANGED
@@ -1,72 +1,23 @@
1
1
  """Common classes and functions that are used across the package."""
2
2
 
3
- from ngio.common._array_io_pipes import (
4
- build_dask_getter,
5
- build_dask_setter,
6
- build_masked_dask_getter,
7
- build_masked_dask_setter,
8
- build_masked_numpy_getter,
9
- build_masked_numpy_setter,
10
- build_numpy_getter,
11
- build_numpy_setter,
12
- )
13
- from ngio.common._array_io_utils import (
14
- ArrayLike,
15
- SlicingInputType,
16
- TransformProtocol,
17
- apply_dask_axes_ops,
18
- apply_numpy_axes_ops,
19
- apply_sequence_axes_ops,
20
- )
21
3
  from ngio.common._dimensions import Dimensions
22
4
  from ngio.common._masking_roi import compute_masking_roi
23
5
  from ngio.common._pyramid import consolidate_pyramid, init_empty_pyramid, on_disk_zoom
24
6
  from ngio.common._roi import (
25
7
  Roi,
26
8
  RoiPixels,
27
- build_roi_dask_getter,
28
- build_roi_dask_setter,
29
- build_roi_masked_dask_getter,
30
- build_roi_masked_dask_setter,
31
- build_roi_masked_numpy_getter,
32
- build_roi_masked_numpy_setter,
33
- build_roi_numpy_getter,
34
- build_roi_numpy_setter,
35
- roi_to_slicing_dict,
36
9
  )
37
- from ngio.common._zoom import dask_zoom, numpy_zoom
10
+ from ngio.common._zoom import InterpolationOrder, dask_zoom, numpy_zoom
38
11
 
39
12
  __all__ = [
40
- "ArrayLike",
41
13
  "Dimensions",
14
+ "InterpolationOrder",
42
15
  "Roi",
43
16
  "RoiPixels",
44
- "SlicingInputType",
45
- "TransformProtocol",
46
- "apply_dask_axes_ops",
47
- "apply_numpy_axes_ops",
48
- "apply_sequence_axes_ops",
49
- "build_dask_getter",
50
- "build_dask_setter",
51
- "build_masked_dask_getter",
52
- "build_masked_dask_setter",
53
- "build_masked_numpy_getter",
54
- "build_masked_numpy_setter",
55
- "build_numpy_getter",
56
- "build_numpy_setter",
57
- "build_roi_dask_getter",
58
- "build_roi_dask_setter",
59
- "build_roi_masked_dask_getter",
60
- "build_roi_masked_dask_setter",
61
- "build_roi_masked_numpy_getter",
62
- "build_roi_masked_numpy_setter",
63
- "build_roi_numpy_getter",
64
- "build_roi_numpy_setter",
65
17
  "compute_masking_roi",
66
18
  "consolidate_pyramid",
67
19
  "dask_zoom",
68
20
  "init_empty_pyramid",
69
21
  "numpy_zoom",
70
22
  "on_disk_zoom",
71
- "roi_to_slicing_dict",
72
23
  ]
@@ -4,153 +4,332 @@ This is not related to the NGFF metadata,
4
4
  but it is based on the actual metadata of the image data.
5
5
  """
6
6
 
7
+ import math
7
8
  from typing import overload
8
9
 
9
- from ngio.ome_zarr_meta import AxesMapper
10
- from ngio.ome_zarr_meta.ngio_specs import AxisType
10
+ from ngio.ome_zarr_meta import (
11
+ AxesHandler,
12
+ )
13
+ from ngio.ome_zarr_meta.ngio_specs._dataset import Dataset
14
+ from ngio.ome_zarr_meta.ngio_specs._pixel_size import PixelSize
11
15
  from ngio.utils import NgioValueError
12
16
 
13
17
 
18
+ def _are_compatible(shape1: int, shape2: int, scaling: float) -> bool:
19
+ """Check if shape2 is consistent with shape1 given pixel sizes.
20
+
21
+ Since we only deal with shape discrepancies due to rounding, we
22
+ shape1, needs to be larger than shape2.
23
+ """
24
+ if shape1 < shape2:
25
+ return _are_compatible(shape2, shape1, 1 / scaling)
26
+ expected_shape2 = shape1 * scaling
27
+ expected_shape2_floor = math.floor(expected_shape2)
28
+ expected_shape2_ceil = math.ceil(expected_shape2)
29
+ return shape2 in {expected_shape2_floor, expected_shape2_ceil}
30
+
31
+
32
+ def require_axes_match(reference: "Dimensions", other: "Dimensions") -> None:
33
+ """Check if two Dimensions objects have the same axes.
34
+
35
+ Besides the channel axis (which is a special case), all axes must be
36
+ present in both Dimensions objects.
37
+
38
+ Args:
39
+ reference (Dimensions): The reference dimensions object to compare against.
40
+ other (Dimensions): The other dimensions object to compare against.
41
+
42
+ Raises:
43
+ NgioValueError: If the axes do not match.
44
+ """
45
+ for s_axis in reference.axes_handler.axes:
46
+ if s_axis.axis_type == "channel":
47
+ continue
48
+ o_axis = other.axes_handler.get_axis(s_axis.name)
49
+ if o_axis is None:
50
+ raise NgioValueError(
51
+ f"Axes do not match. The axis {s_axis.name} "
52
+ f"is not present in either dimensions."
53
+ )
54
+ # Check for axes present in the other dimensions but not in this one
55
+ for o_axis in other.axes_handler.axes:
56
+ if o_axis.axis_type == "channel":
57
+ continue
58
+ s_axis = reference.axes_handler.get_axis(o_axis.name)
59
+ if s_axis is None:
60
+ raise NgioValueError(
61
+ f"Axes do not match. The axis {o_axis.name} "
62
+ f"is not present in either dimensions."
63
+ )
64
+
65
+
66
+ def check_if_axes_match(reference: "Dimensions", other: "Dimensions") -> bool:
67
+ """Check if two Dimensions objects have the same axes.
68
+
69
+ Besides the channel axis (which is a special case), all axes must be
70
+ present in both Dimensions objects.
71
+
72
+ Args:
73
+ reference (Dimensions): The reference dimensions object to compare against.
74
+ other (Dimensions): The other dimensions object to compare against.
75
+
76
+ Returns:
77
+ bool: True if the axes match, False otherwise.
78
+ """
79
+ try:
80
+ require_axes_match(reference, other)
81
+ return True
82
+ except NgioValueError:
83
+ return False
84
+
85
+
86
+ def require_dimensions_match(
87
+ reference: "Dimensions", other: "Dimensions", allow_singleton: bool = False
88
+ ) -> None:
89
+ """Check if two Dimensions objects have the same axes and dimensions.
90
+
91
+ Besides the channel axis, all axes must have the same dimension in
92
+ both images.
93
+
94
+ Args:
95
+ reference (Dimensions): The reference dimensions object to compare against.
96
+ other (Dimensions): The other dimensions object to compare against.
97
+ allow_singleton (bool): Whether to allow singleton dimensions to be
98
+ different. For example, if the input image has shape
99
+ (5, 100, 100) and the label has shape (1, 100, 100).
100
+
101
+ Raises:
102
+ NgioValueError: If the dimensions do not match.
103
+ """
104
+ require_axes_match(reference, other)
105
+ for r_axis in reference.axes_handler.axes:
106
+ if r_axis.axis_type == "channel":
107
+ continue
108
+ o_axis = other.axes_handler.get_axis(r_axis.name)
109
+ assert o_axis is not None # already checked in assert_axes_match
110
+
111
+ r_dim = reference.get(r_axis.name, default=1)
112
+ o_dim = other.get(o_axis.name, default=1)
113
+
114
+ if r_dim != o_dim:
115
+ if allow_singleton and (r_dim == 1 or o_dim == 1):
116
+ continue
117
+ raise NgioValueError(
118
+ f"Dimensions do not match for axis "
119
+ f"{r_axis.name}. Got {r_dim} and {o_dim}."
120
+ )
121
+
122
+
123
+ def check_if_dimensions_match(
124
+ reference: "Dimensions", other: "Dimensions", allow_singleton: bool = False
125
+ ) -> bool:
126
+ """Check if two Dimensions objects have the same axes and dimensions.
127
+
128
+ Besides the channel axis, all axes must have the same dimension in
129
+ both images.
130
+
131
+ Args:
132
+ reference (Dimensions): The reference dimensions object to compare against.
133
+ other (Dimensions): The other dimensions object to compare against.
134
+ allow_singleton (bool): Whether to allow singleton dimensions to be
135
+ different. For example, if the input image has shape
136
+ (5, 100, 100) and the label has shape (1, 100, 100).
137
+
138
+ Returns:
139
+ bool: True if the dimensions match, False otherwise.
140
+ """
141
+ try:
142
+ require_dimensions_match(reference, other, allow_singleton)
143
+ return True
144
+ except NgioValueError:
145
+ return False
146
+
147
+
148
+ def require_rescalable(reference: "Dimensions", other: "Dimensions") -> None:
149
+ """Assert that two images can be rescaled.
150
+
151
+ For this to be true, the images must have the same axes, and
152
+ the pixel sizes must be compatible (i.e. one can be scaled to the other).
153
+
154
+ Args:
155
+ reference (Dimensions): The reference dimensions object to compare against.
156
+ other (Dimensions): The other dimensions object to compare against.
157
+
158
+ """
159
+ require_axes_match(reference, other)
160
+ for ax_r in reference.axes_handler.axes:
161
+ if ax_r.axis_type == "channel":
162
+ continue
163
+ ax_o = other.axes_handler.get_axis(ax_r.name)
164
+ assert ax_o is not None, "Axes do not match."
165
+ px_r = reference.pixel_size.get(ax_r.name, default=1.0)
166
+ px_o = other.pixel_size.get(ax_o.name, default=1.0)
167
+ shape_r = reference.get(ax_r.name, default=1)
168
+ shape_o = other.get(ax_o.name, default=1)
169
+ scale = px_r / px_o
170
+ if not _are_compatible(
171
+ shape1=shape_r,
172
+ shape2=shape_o,
173
+ scaling=scale,
174
+ ):
175
+ raise NgioValueError(
176
+ f"Reference image with shape {reference.shape}, "
177
+ f"and pixel size {reference.pixel_size}, "
178
+ f"cannot be rescaled to "
179
+ f"image with shape {other.shape} "
180
+ f"and pixel size {other.pixel_size}. "
181
+ )
182
+
183
+
184
+ def check_if_rescalable(reference: "Dimensions", other: "Dimensions") -> bool:
185
+ """Check if two images can be rescaled.
186
+
187
+ For this to be true, the images must have the same axes, and
188
+ the pixel sizes must be compatible (i.e. one can be scaled to the other).
189
+
190
+ Args:
191
+ reference (Dimensions): The reference dimensions object to compare against.
192
+ other (Dimensions): The other dimensions object to compare against.
193
+
194
+ Returns:
195
+ bool: True if the images can be rescaled, False otherwise.
196
+ """
197
+ try:
198
+ require_rescalable(reference, other)
199
+ return True
200
+ except NgioValueError:
201
+ return False
202
+
203
+
14
204
  class Dimensions:
15
- """Dimension metadata."""
205
+ """Dimension metadata Handling Class.
206
+
207
+ This class is used to handle and manipulate dimension metadata.
208
+ It provides methods to access and validate dimension information,
209
+ such as shape, axes, and properties like is_2d, is_3d, is_time_series, etc.
210
+ """
211
+
212
+ require_axes_match = require_axes_match
213
+ check_if_axes_match = check_if_axes_match
214
+ require_dimensions_match = require_dimensions_match
215
+ check_if_dimensions_match = check_if_dimensions_match
216
+ require_rescalable = require_rescalable
217
+ check_if_rescalable = check_if_rescalable
16
218
 
17
219
  def __init__(
18
220
  self,
19
221
  shape: tuple[int, ...],
20
- axes_mapper: AxesMapper,
222
+ chunks: tuple[int, ...],
223
+ dataset: Dataset,
21
224
  ) -> None:
22
225
  """Create a Dimension object from a Zarr array.
23
226
 
24
227
  Args:
25
228
  shape: The shape of the Zarr array.
26
- axes_mapper: The axes mapper object.
229
+ chunks: The chunks of the Zarr array.
230
+ dataset: The dataset object.
27
231
  """
28
232
  self._shape = shape
29
- self._axes_mapper = axes_mapper
233
+ self._chunks = chunks
234
+ self._axes_handler = dataset.axes_handler
235
+ self._pixel_size = dataset.pixel_size
30
236
 
31
- if len(self._shape) != len(self._axes_mapper.axes):
237
+ if len(self._shape) != len(self._axes_handler.axes):
32
238
  raise NgioValueError(
33
239
  "The number of dimensions must match the number of axes. "
34
- f"Expected Axis {self._axes_mapper.axes_names} but got shape "
240
+ f"Expected Axis {self._axes_handler.axes_names} but got shape "
35
241
  f"{self._shape}."
36
242
  )
37
243
 
38
244
  def __str__(self) -> str:
39
245
  """Return the string representation of the object."""
40
246
  dims = ", ".join(
41
- f"{ax.on_disk_name}: {s}"
42
- for ax, s in zip(self._axes_mapper.axes, self._shape, strict=True)
247
+ f"{ax.name}: {s}"
248
+ for ax, s in zip(self._axes_handler.axes, self._shape, strict=True)
43
249
  )
44
250
  return f"Dimensions({dims})"
45
251
 
46
- @overload
47
- def get(self, axis_name: str, default: None = None) -> int | None:
48
- pass
49
-
50
- @overload
51
- def get(self, axis_name: str, default: int) -> int:
52
- pass
53
-
54
- def get(self, axis_name: str, default: int | None = None) -> int | None:
55
- """Return the dimension of the given axis name.
56
-
57
- Args:
58
- axis_name: The name of the axis (either canonical or non-canonical).
59
- default: The default value to return if the axis does not exist.
60
- """
61
- index = self._axes_mapper.get_index(axis_name)
62
- if index is None:
63
- return default
64
- return self._shape[index]
65
-
66
- def get_index(self, axis_name: str) -> int | None:
67
- """Return the index of the given axis name.
68
-
69
- Args:
70
- axis_name: The name of the axis (either canonical or non-canonical).
71
- """
72
- return self._axes_mapper.get_index(axis_name)
73
-
74
- def has_axis(self, axis_name: str) -> bool:
75
- """Return whether the axis exists."""
76
- index = self._axes_mapper.get_axis(axis_name)
77
- if index is None:
78
- return False
79
- return True
80
-
81
252
  def __repr__(self) -> str:
82
253
  """Return the string representation of the object."""
83
254
  return str(self)
84
255
 
85
256
  @property
86
- def axes_mapper(self) -> AxesMapper:
87
- """Return the axes mapper object."""
88
- return self._axes_mapper
257
+ def axes_handler(self) -> AxesHandler:
258
+ """Return the axes handler object."""
259
+ return self._axes_handler
260
+
261
+ @property
262
+ def pixel_size(self) -> PixelSize:
263
+ """Return the pixel size object."""
264
+ return self._pixel_size
89
265
 
90
266
  @property
91
267
  def shape(self) -> tuple[int, ...]:
92
268
  """Return the shape as a tuple."""
93
- return tuple(self._shape)
269
+ return self._shape
270
+
271
+ @property
272
+ def chunks(self) -> tuple[int, ...]:
273
+ """Return the chunks as a tuple."""
274
+ return self._chunks
94
275
 
95
276
  @property
96
277
  def axes(self) -> tuple[str, ...]:
97
278
  """Return the axes as a tuple of strings."""
98
- return self._axes_mapper.axes_names
279
+ return self.axes_handler.axes_names
99
280
 
100
281
  @property
101
282
  def is_time_series(self) -> bool:
102
- """Return whether the data is a time series."""
283
+ """Return whether the image is a time series."""
103
284
  if self.get("t", default=1) == 1:
104
285
  return False
105
286
  return True
106
287
 
107
288
  @property
108
289
  def is_2d(self) -> bool:
109
- """Return whether the data is 2D."""
290
+ """Return whether the image is 2D."""
110
291
  if self.get("z", default=1) != 1:
111
292
  return False
112
293
  return True
113
294
 
114
295
  @property
115
296
  def is_2d_time_series(self) -> bool:
116
- """Return whether the data is a 2D time series."""
297
+ """Return whether the image is a 2D time series."""
117
298
  return self.is_2d and self.is_time_series
118
299
 
119
300
  @property
120
301
  def is_3d(self) -> bool:
121
- """Return whether the data is 3D."""
302
+ """Return whether the image is 3D."""
122
303
  return not self.is_2d
123
304
 
124
305
  @property
125
306
  def is_3d_time_series(self) -> bool:
126
- """Return whether the data is a 3D time series."""
307
+ """Return whether the image is a 3D time series."""
127
308
  return self.is_3d and self.is_time_series
128
309
 
129
310
  @property
130
311
  def is_multi_channels(self) -> bool:
131
- """Return whether the data has multiple channels."""
312
+ """Return whether the image has multiple channels."""
132
313
  if self.get("c", default=1) == 1:
133
314
  return False
134
315
  return True
135
316
 
136
- def is_compatible_with(self, other: "Dimensions") -> bool:
137
- """Check if the dimensions are compatible with another Dimensions object.
317
+ @overload
318
+ def get(self, axis_name: str, default: None = None) -> int | None:
319
+ pass
138
320
 
139
- Two dimensions are compatible if:
140
- - they have the same number of axes (excluding channels)
141
- - the shape of each axis is the same
142
- """
143
- if abs(len(self.shape) - len(other.shape)) > 1:
144
- # Since channels are not considered in compatibility
145
- # we allow a difference of 0, 1 n-dimension in the shapes.
146
- return False
321
+ @overload
322
+ def get(self, axis_name: str, default: int) -> int:
323
+ pass
147
324
 
148
- for ax in self._axes_mapper.axes:
149
- if ax.axis_type == AxisType.channel:
150
- continue
325
+ def get(self, axis_name: str, default: int | None = None) -> int | None:
326
+ """Return the dimension/shape of the given axis name.
151
327
 
152
- self_shape = self.get(ax.on_disk_name, default=None)
153
- other_shape = other.get(ax.on_disk_name, default=None)
154
- if self_shape != other_shape:
155
- return False
156
- return True
328
+ Args:
329
+ axis_name: The name of the axis (either canonical or non-canonical).
330
+ default: The default value to return if the axis does not exist.
331
+ """
332
+ index = self.axes_handler.get_index(axis_name)
333
+ if index is None:
334
+ return default
335
+ return self._shape[index]
ngio/common/_pyramid.py CHANGED
@@ -5,8 +5,14 @@ from typing import Literal
5
5
  import dask.array as da
6
6
  import numpy as np
7
7
  import zarr
8
+ from zarr.types import DIMENSION_SEPARATOR
8
9
 
9
- from ngio.common._zoom import _zoom_inputs_check, dask_zoom, numpy_zoom
10
+ from ngio.common._zoom import (
11
+ InterpolationOrder,
12
+ _zoom_inputs_check,
13
+ dask_zoom,
14
+ numpy_zoom,
15
+ )
10
16
  from ngio.utils import (
11
17
  AccessModeLiteral,
12
18
  NgioValueError,
@@ -18,7 +24,7 @@ from ngio.utils import (
18
24
  def _on_disk_numpy_zoom(
19
25
  source: zarr.Array,
20
26
  target: zarr.Array,
21
- order: Literal[0, 1, 2] = 1,
27
+ order: InterpolationOrder,
22
28
  ) -> None:
23
29
  target[...] = numpy_zoom(source[...], target_shape=target.shape, order=order)
24
30
 
@@ -26,7 +32,7 @@ def _on_disk_numpy_zoom(
26
32
  def _on_disk_dask_zoom(
27
33
  source: zarr.Array,
28
34
  target: zarr.Array,
29
- order: Literal[0, 1, 2] = 1,
35
+ order: InterpolationOrder,
30
36
  ) -> None:
31
37
  source_array = da.from_zarr(source)
32
38
  target_array = dask_zoom(source_array, target_shape=target.shape, order=order)
@@ -39,7 +45,7 @@ def _on_disk_dask_zoom(
39
45
  def _on_disk_coarsen(
40
46
  source: zarr.Array,
41
47
  target: zarr.Array,
42
- _order: Literal[0, 1] = 1,
48
+ order: InterpolationOrder = "linear",
43
49
  aggregation_function: Callable | None = None,
44
50
  ) -> None:
45
51
  """Apply a coarsening operation from a source zarr array to a target zarr array.
@@ -47,10 +53,10 @@ def _on_disk_coarsen(
47
53
  Args:
48
54
  source (zarr.Array): The source array to coarsen.
49
55
  target (zarr.Array): The target array to save the coarsened result to.
50
- _order (Literal[0, 1]): The order of interpolation is not really implemented
56
+ order (InterpolationOrder): The order of interpolation is not really implemented
51
57
  for coarsening, but it is kept for compatibility with the zoom function.
52
- _order=1 -> linear interpolation ~ np.mean
53
- _order=0 -> nearest interpolation ~ np.max
58
+ order="linear" -> linear interpolation ~ np.mean
59
+ order="nearest" -> nearest interpolation ~ np.max
54
60
  aggregation_function (np.ufunc): The aggregation function to use.
55
61
  """
56
62
  source_array = da.from_zarr(source)
@@ -64,13 +70,15 @@ def _on_disk_coarsen(
64
70
  )
65
71
 
66
72
  if aggregation_function is None:
67
- if _order == 1:
73
+ if order == "linear":
68
74
  aggregation_function = np.mean
69
- elif _order == 0:
75
+ elif order == "nearest":
70
76
  aggregation_function = np.max
77
+ elif order == "cubic":
78
+ raise NgioValueError("Cubic interpolation is not supported for coarsening.")
71
79
  else:
72
80
  raise NgioValueError(
73
- f"Aggregation function must be provided for order {_order}"
81
+ f"Aggregation function must be provided for order {order}"
74
82
  )
75
83
 
76
84
  coarsening_setup = {}
@@ -96,7 +104,7 @@ def _on_disk_coarsen(
96
104
  def on_disk_zoom(
97
105
  source: zarr.Array,
98
106
  target: zarr.Array,
99
- order: Literal[0, 1, 2] = 1,
107
+ order: InterpolationOrder = "linear",
100
108
  mode: Literal["dask", "numpy", "coarsen"] = "dask",
101
109
  ) -> None:
102
110
  """Apply a zoom operation from a source zarr array to a target zarr array.
@@ -104,7 +112,7 @@ def on_disk_zoom(
104
112
  Args:
105
113
  source (zarr.Array): The source array to zoom.
106
114
  target (zarr.Array): The target array to save the zoomed result to.
107
- order (Literal[0, 1, 2]): The order of interpolation. Defaults to 1.
115
+ order (InterpolationOrder): The order of interpolation. Defaults to "linear".
108
116
  mode (Literal["dask", "numpy", "coarsen"]): The mode to use. Defaults to "dask".
109
117
  """
110
118
  if not isinstance(source, zarr.Array):
@@ -155,7 +163,7 @@ def _find_closest_arrays(
155
163
  def consolidate_pyramid(
156
164
  source: zarr.Array,
157
165
  targets: list[zarr.Array],
158
- order: Literal[0, 1, 2] = 1,
166
+ order: InterpolationOrder = "linear",
159
167
  mode: Literal["dask", "numpy", "coarsen"] = "dask",
160
168
  ) -> None:
161
169
  """Consolidate the Zarr array."""
@@ -177,6 +185,15 @@ def consolidate_pyramid(
177
185
  processed.append(target_image)
178
186
 
179
187
 
188
+ def _maybe_int(value: float | int) -> float | int:
189
+ """Convert a float to an int if it is an integer."""
190
+ if isinstance(value, int):
191
+ return value
192
+ if value.is_integer():
193
+ return int(value)
194
+ return value
195
+
196
+
180
197
  def init_empty_pyramid(
181
198
  store: StoreOrGroup,
182
199
  paths: list[str],
@@ -185,6 +202,8 @@ def init_empty_pyramid(
185
202
  chunks: Sequence[int] | None = None,
186
203
  dtype: str = "uint16",
187
204
  mode: AccessModeLiteral = "a",
205
+ dimension_separator: DIMENSION_SEPARATOR = "/",
206
+ compressor="default",
188
207
  ) -> None:
189
208
  # Return the an Image object
190
209
  if chunks is not None and len(chunks) != len(ref_shape):
@@ -200,6 +219,10 @@ def init_empty_pyramid(
200
219
  "The shape and scaling factor must have the same number of dimensions."
201
220
  )
202
221
 
222
+ # Ensure scaling factors are int if possible
223
+ # To reduce the risk of floating point issues
224
+ scaling_factors = [_maybe_int(s) for s in scaling_factors]
225
+
203
226
  root_group = open_group_wrapper(store, mode=mode)
204
227
 
205
228
  for path in paths:
@@ -213,22 +236,18 @@ def init_empty_pyramid(
213
236
  shape=ref_shape,
214
237
  dtype=dtype,
215
238
  chunks=chunks,
216
- dimension_separator="/",
239
+ dimension_separator=dimension_separator,
217
240
  overwrite=True,
241
+ compressor=compressor,
218
242
  )
219
243
 
220
- # Todo redo this with when a proper build of pyramid is implemented
221
- _shape = []
222
- for s, sc in zip(ref_shape, scaling_factors, strict=True):
223
- if math.floor(s / sc) % 2 == 0:
224
- _shape.append(math.floor(s / sc))
225
- else:
226
- _shape.append(math.ceil(s / sc))
244
+ _shape = [
245
+ math.floor(s / sc) for s, sc in zip(ref_shape, scaling_factors, strict=True)
246
+ ]
227
247
  ref_shape = _shape
228
248
 
229
249
  if chunks is None:
230
250
  chunks = new_arr.chunks
231
- if chunks is None:
232
- raise NgioValueError("Something went wrong with the chunks")
251
+ assert chunks is not None
233
252
  chunks = [min(c, s) for c, s in zip(chunks, ref_shape, strict=True)]
234
253
  return None