ngio-0.3.4-py3-none-any.whl → ngio-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73):
  1. ngio/__init__.py +7 -2
  2. ngio/common/__init__.py +5 -52
  3. ngio/common/_dimensions.py +270 -55
  4. ngio/common/_masking_roi.py +38 -10
  5. ngio/common/_pyramid.py +51 -30
  6. ngio/common/_roi.py +269 -82
  7. ngio/common/_synt_images_utils.py +101 -0
  8. ngio/common/_zoom.py +49 -19
  9. ngio/experimental/__init__.py +5 -0
  10. ngio/experimental/iterators/__init__.py +15 -0
  11. ngio/experimental/iterators/_abstract_iterator.py +390 -0
  12. ngio/experimental/iterators/_feature.py +189 -0
  13. ngio/experimental/iterators/_image_processing.py +130 -0
  14. ngio/experimental/iterators/_mappers.py +48 -0
  15. ngio/experimental/iterators/_rois_utils.py +127 -0
  16. ngio/experimental/iterators/_segmentation.py +235 -0
  17. ngio/hcs/_plate.py +41 -36
  18. ngio/images/__init__.py +22 -1
  19. ngio/images/_abstract_image.py +403 -176
  20. ngio/images/_create.py +31 -15
  21. ngio/images/_create_synt_container.py +138 -0
  22. ngio/images/_image.py +452 -63
  23. ngio/images/_label.py +56 -30
  24. ngio/images/_masked_image.py +387 -129
  25. ngio/images/_ome_zarr_container.py +237 -67
  26. ngio/{common → images}/_table_ops.py +41 -41
  27. ngio/io_pipes/__init__.py +75 -0
  28. ngio/io_pipes/_io_pipes.py +361 -0
  29. ngio/io_pipes/_io_pipes_masked.py +488 -0
  30. ngio/io_pipes/_io_pipes_roi.py +152 -0
  31. ngio/io_pipes/_io_pipes_types.py +56 -0
  32. ngio/io_pipes/_match_shape.py +376 -0
  33. ngio/io_pipes/_ops_axes.py +344 -0
  34. ngio/io_pipes/_ops_slices.py +446 -0
  35. ngio/io_pipes/_ops_slices_utils.py +196 -0
  36. ngio/io_pipes/_ops_transforms.py +104 -0
  37. ngio/io_pipes/_zoom_transform.py +175 -0
  38. ngio/ome_zarr_meta/__init__.py +4 -2
  39. ngio/ome_zarr_meta/ngio_specs/__init__.py +4 -10
  40. ngio/ome_zarr_meta/ngio_specs/_axes.py +186 -175
  41. ngio/ome_zarr_meta/ngio_specs/_channels.py +55 -18
  42. ngio/ome_zarr_meta/ngio_specs/_dataset.py +48 -122
  43. ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +6 -15
  44. ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +38 -87
  45. ngio/ome_zarr_meta/ngio_specs/_pixel_size.py +17 -1
  46. ngio/ome_zarr_meta/v04/_v04_spec_utils.py +34 -31
  47. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/mask.png +0 -0
  48. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/nuclei.png +0 -0
  49. ngio/resources/20200812-CardiomyocyteDifferentiation14-Cycle1_B03/raw.jpg +0 -0
  50. ngio/resources/__init__.py +55 -0
  51. ngio/resources/resource_model.py +36 -0
  52. ngio/tables/backends/_abstract_backend.py +5 -6
  53. ngio/tables/backends/_anndata.py +1 -2
  54. ngio/tables/backends/_anndata_utils.py +3 -3
  55. ngio/tables/backends/_non_zarr_backends.py +1 -1
  56. ngio/tables/backends/_table_backends.py +0 -1
  57. ngio/tables/backends/_utils.py +3 -3
  58. ngio/tables/v1/_roi_table.py +165 -70
  59. ngio/transforms/__init__.py +5 -0
  60. ngio/transforms/_zoom.py +19 -0
  61. ngio/utils/__init__.py +2 -3
  62. ngio/utils/_datasets.py +5 -0
  63. ngio/utils/_logger.py +19 -0
  64. ngio/utils/_zarr_utils.py +6 -6
  65. {ngio-0.3.4.dist-info → ngio-0.4.0.dist-info}/METADATA +24 -22
  66. ngio-0.4.0.dist-info/RECORD +85 -0
  67. ngio/common/_array_pipe.py +0 -288
  68. ngio/common/_axes_transforms.py +0 -64
  69. ngio/common/_common_types.py +0 -5
  70. ngio/common/_slicer.py +0 -96
  71. ngio-0.3.4.dist-info/RECORD +0 -61
  72. {ngio-0.3.4.dist-info → ngio-0.4.0.dist-info}/WHEEL +0 -0
  73. {ngio-0.3.4.dist-info → ngio-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,15 +1,15 @@
1
1
  """Fractal internal module for axes handling."""
2
2
 
3
- from collections.abc import Collection
3
+ from collections.abc import Sequence
4
4
  from enum import Enum
5
- from typing import Literal, TypeVar
5
+ from typing import Literal, TypeAlias, TypeVar
6
6
 
7
- import numpy as np
8
7
  from pydantic import BaseModel, ConfigDict, Field
9
8
 
10
- from ngio.utils import NgioValidationError, NgioValueError, ngio_logger
9
+ from ngio.utils import NgioValidationError, NgioValueError
11
10
 
12
11
  T = TypeVar("T")
12
+ SlicingType: TypeAlias = slice | tuple[int, ...] | int
13
13
 
14
14
  ################################################################################################
15
15
  #
@@ -90,7 +90,7 @@ DefaultTimeUnit = "second"
90
90
  class Axis(BaseModel):
91
91
  """Axis infos model."""
92
92
 
93
- on_disk_name: str
93
+ name: str
94
94
  unit: str | None = None
95
95
  axis_type: AxisType | None = None
96
96
 
@@ -98,27 +98,13 @@ class Axis(BaseModel):
98
98
 
99
99
  def implicit_type_cast(self, cast_type: AxisType) -> "Axis":
100
100
  unit = self.unit
101
- if self.axis_type != cast_type:
102
- ngio_logger.warning(
103
- f"Axis {self.on_disk_name} has type {self.axis_type}. "
104
- f"Casting to {cast_type}."
105
- )
106
-
107
101
  if cast_type == AxisType.time and unit is None:
108
- ngio_logger.warning(
109
- f"Time axis {self.on_disk_name} has unit {self.unit}. "
110
- f"Casting to {DefaultSpaceUnit}."
111
- )
112
102
  unit = DefaultTimeUnit
113
103
 
114
104
  if cast_type == AxisType.space and unit is None:
115
- ngio_logger.warning(
116
- f"Space axis {self.on_disk_name} has unit {unit}. "
117
- f"Casting to {DefaultSpaceUnit}."
118
- )
119
105
  unit = DefaultSpaceUnit
120
106
 
121
- return Axis(on_disk_name=self.on_disk_name, axis_type=cast_type, unit=unit)
107
+ return Axis(name=self.name, axis_type=cast_type, unit=unit)
122
108
 
123
109
  def canonical_axis_cast(self, canonical_name: str) -> "Axis":
124
110
  """Cast the implicit axis to the correct type."""
@@ -170,10 +156,40 @@ class AxesSetup(BaseModel):
170
156
 
171
157
  model_config = ConfigDict(extra="forbid", frozen=True)
172
158
 
173
-
174
- def _check_unique_names(axes: Collection[Axis]):
159
+ def canonical_map(self) -> dict[str, str]:
160
+ """Get the canonical map of axes."""
161
+ return {
162
+ "t": self.t,
163
+ "c": self.c,
164
+ "z": self.z,
165
+ "y": self.y,
166
+ "x": self.x,
167
+ }
168
+
169
+ def get_on_disk_name(self, canonical_name: str) -> str | None:
170
+ """Get the on disk name of the axis by its canonical name."""
171
+ canonical_map = self.canonical_map()
172
+ return canonical_map.get(canonical_name, None)
173
+
174
+ def inverse_canonical_map(self) -> dict[str, str]:
175
+ """Get the on disk map of axes."""
176
+ return {
177
+ self.t: "t",
178
+ self.c: "c",
179
+ self.z: "z",
180
+ self.y: "y",
181
+ self.x: "x",
182
+ }
183
+
184
+ def get_canonical_name(self, on_disk_name: str) -> str | None:
185
+ """Get the canonical name of the axis by its on disk name."""
186
+ inv_map = self.inverse_canonical_map()
187
+ return inv_map.get(on_disk_name, None)
188
+
189
+
190
+ def _check_unique_names(axes: Sequence[Axis]):
175
191
  """Check if all axes on disk have unique names."""
176
- names = [ax.on_disk_name for ax in axes]
192
+ names = [ax.name for ax in axes]
177
193
  if len(set(names)) != len(names):
178
194
  duplicates = {item for item in names if names.count(item) > 1}
179
195
  raise NgioValidationError(
@@ -190,41 +206,41 @@ def _check_non_canonical_axes(axes_setup: AxesSetup, allow_non_canonical_axes: b
190
206
  )
191
207
 
192
208
 
193
- def _check_axes_validity(axes: Collection[Axis], axes_setup: AxesSetup):
209
+ def _check_axes_validity(axes: Sequence[Axis], axes_setup: AxesSetup):
194
210
  """Check if all axes are valid."""
195
211
  _axes_setup = axes_setup.model_dump(exclude={"others"})
196
212
  _all_known_axes = [*_axes_setup.values(), *axes_setup.others]
197
213
  for ax in axes:
198
- if ax.on_disk_name not in _all_known_axes:
214
+ if ax.name not in _all_known_axes:
199
215
  raise NgioValidationError(
200
- f"Invalid axis name '{ax.on_disk_name}'. "
201
- f"Please correct map `{ax.on_disk_name}` "
216
+ f"Invalid axis name '{ax.name}'. "
217
+ f"Please correct map `{ax.name}` "
202
218
  f"using the AxesSetup model {axes_setup}"
203
219
  )
204
220
 
205
221
 
206
222
  def _check_canonical_order(
207
- axes: Collection[Axis], axes_setup: AxesSetup, strict_canonical_order: bool
223
+ axes: Sequence[Axis], axes_setup: AxesSetup, strict_canonical_order: bool
208
224
  ):
209
225
  """Check if the axes are in the canonical order."""
210
226
  if not strict_canonical_order:
211
227
  return
212
- _on_disk_names = [ax.on_disk_name for ax in axes]
228
+ _names = [ax.name for ax in axes]
213
229
  _canonical_order = []
214
230
  for name in canonical_axes_order():
215
231
  mapped_name = getattr(axes_setup, name)
216
- if mapped_name in _on_disk_names:
232
+ if mapped_name in _names:
217
233
  _canonical_order.append(mapped_name)
218
234
 
219
- if _on_disk_names != _canonical_order:
235
+ if _names != _canonical_order:
220
236
  raise NgioValidationError(
221
237
  f"Invalid axes order. The axes must be in the canonical order. "
222
- f"Expected {_canonical_order}, but found {_on_disk_names}"
238
+ f"Expected {_canonical_order}, but found {_names}"
223
239
  )
224
240
 
225
241
 
226
242
  def validate_axes(
227
- axes: Collection[Axis],
243
+ axes: Sequence[Axis],
228
244
  axes_setup: AxesSetup,
229
245
  allow_non_canonical_axes: bool = False,
230
246
  strict_canonical_order: bool = False,
@@ -246,33 +262,20 @@ def validate_axes(
246
262
  )
247
263
 
248
264
 
249
- class AxesTransformation(BaseModel):
250
- model_config = ConfigDict(extra="forbid", frozen=True, arbitrary_types_allowed=True)
251
-
252
-
253
- class AxesTranspose(AxesTransformation):
254
- axes: tuple[int, ...]
255
-
256
-
257
- class AxesExpand(AxesTransformation):
258
- axes: tuple[int, ...]
259
-
260
-
261
- class AxesSqueeze(AxesTransformation):
262
- axes: tuple[int, ...]
263
-
264
-
265
- class AxesMapper:
266
- """Map on disk axes to canonical axes.
267
-
268
- This class is used to map the on disk axes to the canonical axes.
265
+ class AxesHandler:
266
+ """This class is used to handle and operate on OME-Zarr axes.
269
267
 
268
+ The class also provides:
269
+ - methods to reorder, squeeze and expand axes.
270
+ - methods to validate the axes.
271
+ - methods to get axis by name or index.
272
+ - methods to operate on the axes.
270
273
  """
271
274
 
272
275
  def __init__(
273
276
  self,
274
277
  # spec dictated args
275
- on_disk_axes: Collection[Axis],
278
+ axes: Sequence[Axis],
276
279
  # user defined args
277
280
  axes_setup: AxesSetup | None = None,
278
281
  allow_non_canonical_axes: bool = False,
@@ -281,7 +284,7 @@ class AxesMapper:
281
284
  """Create a new AxesMapper object.
282
285
 
283
286
  Args:
284
- on_disk_axes (list[Axis]): The axes on disk.
287
+ axes (list[Axis]): The axes on disk.
285
288
  axes_setup (AxesSetup, optional): The axis setup. Defaults to None.
286
289
  allow_non_canonical_axes (bool, optional): Allow non canonical axes.
287
290
  strict_canonical_order (bool, optional): Check if the axes are in the
@@ -290,7 +293,7 @@ class AxesMapper:
290
293
  axes_setup = axes_setup if axes_setup is not None else AxesSetup()
291
294
 
292
295
  validate_axes(
293
- axes=on_disk_axes,
296
+ axes=axes,
294
297
  axes_setup=axes_setup,
295
298
  allow_non_canonical_axes=allow_non_canonical_axes,
296
299
  strict_canonical_order=strict_canonical_order,
@@ -300,56 +303,42 @@ class AxesMapper:
300
303
  self._strict_canonical_order = strict_canonical_order
301
304
 
302
305
  self._canonical_order = canonical_axes_order()
303
- self._extended_canonical_order = [*axes_setup.others, *self._canonical_order]
304
306
 
305
- self._on_disk_axes = on_disk_axes
307
+ self._axes = axes
306
308
  self._axes_setup = axes_setup
307
309
 
308
- self._name_mapping = self._compute_name_mapping()
309
310
  self._index_mapping = self._compute_index_mapping()
310
311
 
311
312
  # Validate the axes type and cast them if necessary
312
313
  # This needs to be done after the name mapping is computed
313
- self.validate_axex_type()
314
-
315
- def _compute_name_mapping(self):
316
- """Compute the name mapping.
317
-
318
- The name mapping is a dictionary with keys the canonical axes names
319
- and values the on disk axes names.
320
- """
321
- _name_mapping = {}
322
- axis_setup_dict = self._axes_setup.model_dump(exclude={"others"})
323
- _on_disk_names = self.on_disk_axes_names
324
- for canonical_key, on_disk_value in axis_setup_dict.items():
325
- if on_disk_value in _on_disk_names:
326
- _name_mapping[canonical_key] = on_disk_value
327
- else:
328
- _name_mapping[canonical_key] = None
329
-
330
- for on_disk_name in _on_disk_names:
331
- if on_disk_name not in _name_mapping.keys():
332
- _name_mapping[on_disk_name] = on_disk_name
333
-
334
- for other in self._axes_setup.others:
335
- if other not in _name_mapping.keys():
336
- _name_mapping[other] = None
337
- return _name_mapping
314
+ self.validate_axes_type()
338
315
 
339
316
  def _compute_index_mapping(self):
340
317
  """Compute the index mapping.
341
318
 
342
319
  The index mapping is a dictionary with keys the canonical axes names
343
320
  and values the on disk axes index.
321
+
322
+ Example:
323
+ If the on disk axes are ['channel', 't', 'z', 'y', 'x'],
324
+ the index mapping will be:
325
+ {
326
+ 'c': 0,
327
+ 'channel': 0,
328
+ 't': 1,
329
+ 'z': 2,
330
+ 'y': 3,
331
+ 'x': 4,
332
+ }
344
333
  """
345
334
  _index_mapping = {}
346
- for canonical_key, on_disk_value in self._name_mapping.items():
347
- if on_disk_value is not None:
348
- _index_mapping[canonical_key] = self.on_disk_axes_names.index(
349
- on_disk_value
350
- )
351
- else:
352
- _index_mapping[canonical_key] = None
335
+ for i, ax in enumerate(self.axes_names):
336
+ _index_mapping[ax] = i
337
+ # If the axis is not in the canonical order we also set it.
338
+ canonical_map = self._axes_setup.canonical_map()
339
+ for canonical_name, on_disk_name in canonical_map.items():
340
+ if on_disk_name in _index_mapping.keys():
341
+ _index_mapping[canonical_name] = _index_mapping[on_disk_name]
353
342
  return _index_mapping
354
343
 
355
344
  @property
@@ -358,12 +347,12 @@ class AxesMapper:
358
347
  return self._axes_setup
359
348
 
360
349
  @property
361
- def on_disk_axes(self) -> list[Axis]:
362
- return list(self._on_disk_axes)
350
+ def axes(self) -> tuple[Axis, ...]:
351
+ return tuple(self._axes)
363
352
 
364
353
  @property
365
- def on_disk_axes_names(self) -> list[str]:
366
- return [ax.on_disk_name for ax in self._on_disk_axes]
354
+ def axes_names(self) -> tuple[str, ...]:
355
+ return tuple(ax.name for ax in self._axes)
367
356
 
368
357
  @property
369
358
  def allow_non_canonical_axes(self) -> bool:
@@ -375,103 +364,119 @@ class AxesMapper:
375
364
  """Return if strict canonical order is enforced."""
376
365
  return self._strict_canonical_order
377
366
 
367
+ @property
368
+ def space_unit(self) -> str | None:
369
+ """Return the space unit for a given axis."""
370
+ x_axis = self.get_axis("x")
371
+ y_axis = self.get_axis("y")
372
+
373
+ if x_axis is None or y_axis is None:
374
+ raise NgioValidationError(
375
+ "The dataset must have x and y axes to determine the space unit."
376
+ )
377
+
378
+ if x_axis.unit == y_axis.unit:
379
+ return x_axis.unit
380
+ else:
381
+ raise NgioValidationError(
382
+ "Inconsistent space units. "
383
+ f"x={x_axis.unit} and y={y_axis.unit} should have the same unit."
384
+ )
385
+
386
+ @property
387
+ def time_unit(self) -> str | None:
388
+ """Return the time unit for a given axis."""
389
+ t_axis = self.get_axis("t")
390
+ if t_axis is None:
391
+ return None
392
+ return t_axis.unit
393
+
394
+ def to_units(
395
+ self,
396
+ *,
397
+ space_unit: SpaceUnits = DefaultSpaceUnit,
398
+ time_unit: TimeUnits = DefaultTimeUnit,
399
+ ) -> "AxesHandler":
400
+ """Convert the pixel size to the given units.
401
+
402
+ Args:
403
+ space_unit(str): The space unit to convert to.
404
+ time_unit(str): The time unit to convert to.
405
+ """
406
+ new_axes = []
407
+ for ax in self.axes:
408
+ if ax.axis_type == AxisType.space:
409
+ new_ax = Axis(
410
+ name=ax.name,
411
+ axis_type=ax.axis_type,
412
+ unit=space_unit,
413
+ )
414
+ new_axes.append(new_ax)
415
+ elif ax.axis_type == AxisType.time:
416
+ new_ax = Axis(name=ax.name, axis_type=ax.axis_type, unit=time_unit)
417
+ new_axes.append(new_ax)
418
+ else:
419
+ new_axes.append(ax)
420
+
421
+ return AxesHandler(
422
+ axes=new_axes,
423
+ axes_setup=self.axes_setup,
424
+ allow_non_canonical_axes=self.allow_non_canonical_axes,
425
+ strict_canonical_order=self.strict_canonical_order,
426
+ )
427
+
378
428
  def get_index(self, name: str) -> int | None:
379
429
  """Get the index of the axis by name."""
380
- if name not in self._index_mapping.keys():
381
- raise NgioValueError(
382
- f"Invalid axis name '{name}'. "
383
- f"Possible values are {self._index_mapping.keys()}"
384
- )
385
- return self._index_mapping[name]
430
+ return self._index_mapping.get(name, None)
431
+
432
+ def has_axis(self, axis_name: str) -> bool:
433
+ """Return whether the axis exists."""
434
+ index = self.get_index(axis_name)
435
+ if index is None:
436
+ return False
437
+ return True
438
+
439
+ def get_canonical_name(self, name: str) -> str | None:
440
+ """Get the canonical name of the axis by name."""
441
+ return self._axes_setup.get_canonical_name(name)
386
442
 
387
443
  def get_axis(self, name: str) -> Axis | None:
388
444
  """Get the axis object by name."""
389
445
  index = self.get_index(name)
390
446
  if index is None:
391
447
  return None
392
- return self.on_disk_axes[index]
448
+ return self.axes[index]
393
449
 
394
- def validate_axex_type(self):
450
+ def validate_axes_type(self):
395
451
  """Validate the axes type.
396
452
 
397
453
  If the axes type is not correct, a warning is issued.
398
454
  and the axis is implicitly cast to the correct type.
399
455
  """
400
456
  new_axes = []
401
- for axes in self.on_disk_axes:
457
+ for axes in self.axes:
402
458
  for name in self._canonical_order:
403
459
  if axes == self.get_axis(name):
404
460
  new_axes.append(axes.canonical_axis_cast(name))
405
461
  break
406
462
  else:
407
463
  new_axes.append(axes)
408
- self._on_disk_axes = new_axes
409
-
410
- def _change_order(
411
- self, names: Collection[str]
412
- ) -> tuple[tuple[int, ...], tuple[int, ...]]:
413
- unique_names = set()
414
- for name in names:
415
- if name not in self._index_mapping.keys():
416
- raise NgioValueError(
417
- f"Invalid axis name '{name}'. "
418
- f"Possible values are {self._index_mapping.keys()}"
419
- )
420
- _unique_name = self._name_mapping.get(name)
421
- if _unique_name is None:
422
- continue
423
- if _unique_name in unique_names:
424
- raise NgioValueError(
425
- f"Duplicate axis name, two or more '{_unique_name}' were found. "
426
- f"Please provide unique names."
427
- )
428
- unique_names.add(_unique_name)
464
+ self._axes = new_axes
429
465
 
430
- if len(self.on_disk_axes_names) > len(unique_names):
431
- missing_names = set(self.on_disk_axes_names) - unique_names
432
- raise NgioValueError(
433
- f"Some axes where not queried. "
434
- f"Please provide the following missing axes {missing_names}"
435
- )
436
- _indices, _insert = [], []
437
- for i, name in enumerate(names):
438
- _index = self._index_mapping[name]
439
- if _index is None:
440
- _insert.append(i)
441
- else:
442
- _indices.append(self._index_mapping[name])
443
- return tuple(_indices), tuple(_insert)
444
-
445
- def to_order(self, names: Collection[str]) -> tuple[AxesTransformation, ...]:
446
- """Get the new order of the axes."""
447
- _indices, _insert = self._change_order(names)
448
- return AxesTranspose(axes=_indices), AxesExpand(axes=_insert)
449
-
450
- def from_order(self, names: Collection[str]) -> tuple[AxesTransformation, ...]:
451
- """Get the new order of the axes."""
452
- _indices, _insert = self._change_order(names)
453
- # Inverse transpose is just the transpose with the inverse indices
454
- _reverse_indices = tuple(np.argsort(_indices))
455
- return AxesSqueeze(axes=_insert), AxesTranspose(axes=_reverse_indices)
456
-
457
- def to_canonical(self) -> tuple[AxesTransformation, ...]:
458
- """Get the new order of the axes."""
459
- return self.to_order(self._extended_canonical_order)
460
-
461
- def from_canonical(self) -> tuple[AxesTransformation, ...]:
462
- """Get the new order of the axes."""
463
- return self.from_order(self._extended_canonical_order)
464
-
465
-
466
- def canonical_axes(
467
- axes_names: Collection[str],
468
- space_units: SpaceUnits | None = DefaultSpaceUnit,
469
- time_units: TimeUnits | None = DefaultTimeUnit,
470
- ) -> list[Axis]:
466
+
467
+ def build_canonical_axes_handler(
468
+ axes_names: Sequence[str],
469
+ space_units: SpaceUnits | str | None = DefaultSpaceUnit,
470
+ time_units: TimeUnits | str | None = DefaultTimeUnit,
471
+ # user defined args
472
+ axes_setup: AxesSetup | None = None,
473
+ allow_non_canonical_axes: bool = False,
474
+ strict_canonical_order: bool = False,
475
+ ) -> AxesHandler:
471
476
  """Create a new canonical axes mapper.
472
477
 
473
478
  Args:
474
- axes_names (Collection[str] | int): The axes names on disk.
479
+ axes_names (Sequence[str] | int): The axes names on disk.
475
480
  - The axes should be in ['t', 'c', 'z', 'y', 'x']
476
481
  - The axes should be in strict canonical order.
477
482
  - If an integer is provided, the axes are created from the last axis
@@ -479,25 +484,31 @@ def canonical_axes(
479
484
  e.g. 3 -> ["z", "y", "x"]
480
485
  space_units (SpaceUnits, optional): The space units. Defaults to None.
481
486
  time_units (TimeUnits, optional): The time units. Defaults to None.
487
+ axes_setup (AxesSetup, optional): The axis setup. Defaults to None.
488
+ allow_non_canonical_axes (bool, optional): Allow non canonical axes.
489
+ Defaults to False.
490
+ strict_canonical_order (bool, optional): Check if the axes are in the
491
+ canonical order. Defaults to False.
482
492
 
483
493
  """
484
494
  axes = []
485
495
  for name in axes_names:
486
496
  match name:
487
497
  case "t":
488
- axes.append(
489
- Axis(on_disk_name=name, axis_type=AxisType.time, unit=time_units)
490
- )
498
+ axes.append(Axis(name=name, axis_type=AxisType.time, unit=time_units))
491
499
  case "c":
492
- axes.append(Axis(on_disk_name=name, axis_type=AxisType.channel))
500
+ axes.append(Axis(name=name, axis_type=AxisType.channel))
493
501
  case "z" | "y" | "x":
494
- axes.append(
495
- Axis(on_disk_name=name, axis_type=AxisType.space, unit=space_units)
496
- )
502
+ axes.append(Axis(name=name, axis_type=AxisType.space, unit=space_units))
497
503
  case _:
498
504
  raise NgioValueError(
499
505
  f"Invalid axis name '{name}'. "
500
506
  "Only 't', 'c', 'z', 'y', 'x' are allowed."
501
507
  )
502
508
 
503
- return axes
509
+ return AxesHandler(
510
+ axes=axes,
511
+ axes_setup=axes_setup,
512
+ allow_non_canonical_axes=allow_non_canonical_axes,
513
+ strict_canonical_order=strict_canonical_order,
514
+ )
@@ -3,7 +3,7 @@
3
3
  Stores the same information as the Omero section of the ngff 0.4 metadata.
4
4
  """
5
5
 
6
- from collections.abc import Collection
6
+ from collections.abc import Sequence
7
7
  from difflib import SequenceMatcher
8
8
  from enum import Enum
9
9
  from typing import Any, TypeVar
@@ -77,7 +77,8 @@ class NgioColors(str, Enum):
77
77
  # try to match the color to the channel name
78
78
  similarity[color] = SequenceMatcher(None, channel_name, color).ratio()
79
79
  # Get the color with the highest similarity
80
- color_str = max(similarity, key=similarity.get) # type: ignore
80
+ color_str = max(similarity, key=similarity.get) # type: ignore (max type overload fails to infer type)
81
+ assert isinstance(color_str, str), "Color name must be a string."
81
82
  return NgioColors.__members__[color_str]
82
83
 
83
84
 
@@ -287,7 +288,7 @@ class Channel(BaseModel):
287
288
  T = TypeVar("T")
288
289
 
289
290
 
290
- def _check_elements(elements: Collection[T], expected_type: Any) -> Collection[T]:
291
+ def _check_elements(elements: Sequence[T], expected_type: Any) -> Sequence[T]:
291
292
  """Check that the elements are of the same type."""
292
293
  if len(elements) == 0:
293
294
  raise NgioValidationError("At least one element must be provided.")
@@ -301,7 +302,7 @@ def _check_elements(elements: Collection[T], expected_type: Any) -> Collection[T
301
302
  return elements
302
303
 
303
304
 
304
- def _check_unique(elements: Collection[T]) -> Collection[T]:
305
+ def _check_unique(elements: Sequence[T]) -> Sequence[T]:
305
306
  """Check that the elements are unique."""
306
307
  if len(set(elements)) != len(elements):
307
308
  raise NgioValidationError("All elements must be unique.")
@@ -329,35 +330,35 @@ class ChannelsMeta(BaseModel):
329
330
  @classmethod
330
331
  def default_init(
331
332
  cls,
332
- labels: Collection[str | None] | int,
333
- wavelength_id: Collection[str | None] | None = None,
334
- colors: Collection[str | NgioColors | None] | None = None,
335
- start: Collection[int | float | None] | int | float | None = None,
336
- end: Collection[int | float | None] | int | float | None = None,
337
- active: Collection[bool | None] | None = None,
333
+ labels: Sequence[str | None] | int,
334
+ wavelength_id: Sequence[str | None] | None = None,
335
+ colors: Sequence[str | NgioColors | None] | None = None,
336
+ start: Sequence[int | float | None] | int | float | None = None,
337
+ end: Sequence[int | float | None] | int | float | None = None,
338
+ active: Sequence[bool | None] | None = None,
338
339
  data_type: Any = np.uint16,
339
340
  **omero_kwargs: dict,
340
341
  ) -> "ChannelsMeta":
341
342
  """Create a ChannelsMeta object with the default unit.
342
343
 
343
344
  Args:
344
- labels(Collection[str | None] | int): The list of channels names
345
+ labels(Sequence[str | None] | int): The list of channels names
345
346
  in the image. If an integer is provided, the channels will be
346
347
  named "channel_i".
347
- wavelength_id(Collection[str | None] | None): The wavelength ID of the
348
+ wavelength_id(Sequence[str | None] | None): The wavelength ID of the
348
349
  channel. If None, the wavelength ID will be the same as the
349
350
  channel name.
350
- colors(Collection[str | NgioColors | None] | None): The list of
351
+ colors(Sequence[str | NgioColors | None] | None): The list of
351
352
  colors for the channels. If None, the colors will be random.
352
- start(Collection[int | float | None] | int | float | None): The start
353
+ start(Sequence[int | float | None] | int | float | None): The start
353
354
  value of the channel. If None, the start value will be the
354
355
  minimum value of the data type.
355
- end(Collection[int | float | None] | int | float | None): The end
356
+ end(Sequence[int | float | None] | int | float | None): The end
356
357
  value of the channel. If None, the end value will be the
357
358
  maximum value of the data type.
358
359
  data_type(Any): The data type of the channel. Will be used to set the
359
360
  min and max values of the channel.
360
- active (Collection[bool | None] | None): Whether the channel should
361
+ active (Sequence[bool | None] | None): Whether the channel should
361
362
  be shown by default.
362
363
  omero_kwargs(dict): Extra fields to store in the omero attributes.
363
364
  """
@@ -367,9 +368,9 @@ class ChannelsMeta(BaseModel):
367
368
  labels = _check_elements(labels, str)
368
369
  labels = _check_unique(labels)
369
370
 
370
- _wavelength_id: Collection[str | None] = [None] * len(labels)
371
+ _wavelength_id: Sequence[str | None] = [None] * len(labels)
371
372
  if wavelength_id is None:
372
- _wavelength_id: Collection[str | None] = [None] * len(labels)
373
+ _wavelength_id: Sequence[str | None] = [None] * len(labels)
373
374
  else:
374
375
  _wavelength_id = _check_elements(wavelength_id, str)
375
376
  _wavelength_id = _check_unique(wavelength_id)
@@ -425,3 +426,39 @@ class ChannelsMeta(BaseModel):
425
426
  )
426
427
  )
427
428
  return cls(channels=channels, **omero_kwargs)
429
+
430
+ @property
431
+ def channel_labels(self) -> list[str]:
432
+ """Get the labels of the channels in the image."""
433
+ return [channel.label for channel in self.channels]
434
+
435
+ @property
436
+ def channel_wavelength_ids(self) -> list[str | None]:
437
+ """Get the wavelength IDs of the channels in the image."""
438
+ return [channel.wavelength_id for channel in self.channels]
439
+
440
+ def get_channel_idx(
441
+ self, channel_label: str | None = None, wavelength_id: str | None = None
442
+ ) -> int:
443
+ """Get the index of a channel by its label or wavelength ID."""
444
+ # Only one of the arguments must be provided
445
+ if channel_label is not None and wavelength_id is not None:
446
+ raise NgioValueError(
447
+ "get_channel_idx must receive either label or wavelength_id, not both."
448
+ )
449
+
450
+ if channel_label is not None:
451
+ if channel_label not in self.channel_labels:
452
+ raise NgioValueError(f"Channel with label {channel_label} not found.")
453
+ return self.channel_labels.index(channel_label)
454
+
455
+ if wavelength_id is not None:
456
+ if wavelength_id not in self.channel_wavelength_ids:
457
+ raise NgioValueError(
458
+ f"Channel with wavelength ID {wavelength_id} not found."
459
+ )
460
+ return self.channel_wavelength_ids.index(wavelength_id)
461
+
462
+ raise NgioValueError(
463
+ "get_channel_idx must receive either label or wavelength_id"
464
+ )