tensogram-xarray 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ # (C) Copyright 2026- ECMWF and individual contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation nor
7
+ # does it submit to any jurisdiction.
8
+
9
+ """tensogram-xarray: xarray backend engine for tensogram .tgm files.
10
+
11
+ Provides ``engine="tensogram"`` for ``xr.open_dataset()`` and a top-level
12
+ ``open_datasets()`` function for multi-message .tgm files that auto-groups
13
+ incompatible objects into separate Datasets.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from tensogram_xarray.backend import TensogramBackendEntrypoint
19
+ from tensogram_xarray.merge import open_datasets
20
+
21
# Names exported by ``from tensogram_xarray import *``; both are imported above.
__all__ = ["TensogramBackendEntrypoint", "open_datasets"]
@@ -0,0 +1,408 @@
1
+ # (C) Copyright 2026- ECMWF and individual contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation nor
7
+ # does it submit to any jurisdiction.
8
+
9
+ """Lazy-loading backend array for tensogram data objects.
10
+
11
+ ``TensogramBackendArray`` implements :class:`xarray.backends.BackendArray` so
12
+ that tensor payloads are decoded on demand. For compressors that support
13
+ random access (``none``, ``szip``, ``blosc2``, ``zfp`` fixed-rate) and have no
14
+ ``shuffle`` filter, N-dimensional slice requests are mapped to flat element
15
+ ranges and decoded via ``tensogram.decode_range()``. Otherwise the full
16
+ object is decoded and sliced in-memory via ``tensogram.decode_object()``.
17
+
18
+ A ratio-based heuristic controls when partial reads are used: if the
19
+ fraction of requested elements exceeds ``range_threshold`` (default 0.5),
20
+ the backend falls back to a full decode.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import logging
26
+ import math
27
+ import os
28
+ import threading
29
+ from itertools import product as iterproduct
30
+ from typing import Any
31
+
32
+ import numpy as np
33
+ from xarray.backends import BackendArray
34
+ from xarray.core import indexing
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Compressor values that support partial decode via decode_range().
39
+ _RANDOM_ACCESS_COMPRESSORS = frozenset({"none", "szip", "blosc2", "zfp"})
40
+
41
+ # Filters that break contiguous byte ranges (shuffle rearranges bytes).
42
+ _RANGE_BLOCKING_FILTERS = frozenset({"shuffle"})
43
+
44
+ # Default ratio threshold: use decode_range when the requested fraction of
45
+ # total elements is at or below this value.
46
+ DEFAULT_RANGE_THRESHOLD = 0.5
47
+
48
+
49
+ def _supports_range_decode(descriptor: Any) -> bool:
50
+ """Return *True* if the object's pipeline allows ``decode_range()``."""
51
+ compression = getattr(descriptor, "compression", "none")
52
+ filt = getattr(descriptor, "filter", "none")
53
+
54
+ if filt in _RANGE_BLOCKING_FILTERS:
55
+ return False
56
+
57
+ if compression not in _RANDOM_ACCESS_COMPRESSORS:
58
+ return False
59
+
60
+ # zfp supports range decode only in fixed_rate mode.
61
+ if compression == "zfp":
62
+ params = getattr(descriptor, "params", {}) or {}
63
+ if params.get("zfp_mode") != "fixed_rate":
64
+ return False
65
+
66
+ return True
67
+
68
+
69
+ def _is_contiguous_slice(key: tuple) -> bool:
70
+ """Return *True* when *key* is a tuple of unit-stride slices."""
71
+ for k in key:
72
+ if not isinstance(k, slice):
73
+ return False
74
+ # Reject non-unit strides (step != 1 and step != None).
75
+ if k.step is not None and k.step != 1:
76
+ return False
77
+ return True
78
+
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # N-D slice -> flat element ranges
82
+ # ---------------------------------------------------------------------------
83
+
84
+
85
+ def _nd_slice_to_flat_ranges(
86
+ shape: tuple[int, ...],
87
+ key: tuple[slice, ...],
88
+ ) -> tuple[list[tuple[int, int]], tuple[int, ...]]:
89
+ """Map an N-dimensional slice to flat ``(start, count)`` ranges.
90
+
91
+ For a C-contiguous (row-major) array the elements of a hyper-rectangular
92
+ slice are **not** contiguous in general. Contiguous runs exist only
93
+ along the innermost (rightmost) axis. This function decomposes the
94
+ N-D slice into the minimal set of flat ranges that cover exactly the
95
+ requested elements, then merges adjacent ranges.
96
+
97
+ Parameters
98
+ ----------
99
+ shape
100
+ Shape of the full tensor.
101
+ key
102
+ Tuple of ``slice`` objects (one per dimension, unit-stride only).
103
+
104
+ Returns
105
+ -------
106
+ flat_ranges
107
+ List of ``(element_offset, element_count)`` in the flattened array.
108
+ output_shape
109
+ Shape of the result after slicing.
110
+ """
111
+ ndim = len(shape)
112
+
113
+ # Parse each slice into (start, count).
114
+ dim_ranges: list[tuple[int, int]] = []
115
+ output_dims: list[int] = []
116
+ for slc, d in zip(key, shape):
117
+ s, e, _ = slc.indices(d)
118
+ count = e - s
119
+ dim_ranges.append((s, count))
120
+ output_dims.append(count)
121
+ output_shape = tuple(output_dims)
122
+
123
+ total_output = math.prod(output_dims)
124
+ if total_output == 0:
125
+ return [], output_shape
126
+
127
+ # Compute C-contiguous strides (in elements).
128
+ strides = [1] * ndim
129
+ for i in range(ndim - 2, -1, -1):
130
+ strides[i] = strides[i + 1] * shape[i + 1]
131
+
132
+ # Find the *split point* k: the innermost dimension whose slice is
133
+ # NOT a full slice. All dimensions k+1 .. n-1 are full slices, so
134
+ # their elements form a contiguous block.
135
+ split = -1 # -1 means all dims are full
136
+ for i in range(ndim - 1, -1, -1):
137
+ start_i, count_i = dim_ranges[i]
138
+ if start_i != 0 or count_i != shape[i]:
139
+ split = i
140
+ break
141
+
142
+ if split == -1:
143
+ # Every dimension is a full slice -- one range covering everything.
144
+ return [(0, math.prod(shape))], output_shape
145
+
146
+ # Contiguous block size: count at split dim * product of trailing dims.
147
+ block_size = dim_ranges[split][1]
148
+ for i in range(split + 1, ndim):
149
+ block_size *= shape[i]
150
+
151
+ block_start_within_row = dim_ranges[split][0] * strides[split]
152
+
153
+ if split == 0:
154
+ # No outer dimensions to iterate.
155
+ return [(dim_ranges[0][0] * strides[0], block_size)], output_shape
156
+
157
+ # Generate one range per combination of outer-dimension indices.
158
+ outer_index_ranges = [
159
+ range(dim_ranges[i][0], dim_ranges[i][0] + dim_ranges[i][1]) for i in range(split)
160
+ ]
161
+
162
+ flat_ranges: list[tuple[int, int]] = []
163
+ for idx in iterproduct(*outer_index_ranges):
164
+ base = sum(idx[j] * strides[j] for j in range(split))
165
+ flat_ranges.append((base + block_start_within_row, block_size))
166
+
167
+ # Merge adjacent ranges (consecutive with no gap).
168
+ if len(flat_ranges) > 1:
169
+ merged: list[tuple[int, int]] = [flat_ranges[0]]
170
+ for start, count in flat_ranges[1:]:
171
+ prev_start, prev_count = merged[-1]
172
+ if start == prev_start + prev_count:
173
+ merged[-1] = (prev_start, prev_count + count)
174
+ else:
175
+ merged.append((start, count))
176
+ flat_ranges = merged
177
+
178
+ return flat_ranges, output_shape
179
+
180
+
181
+ # ---------------------------------------------------------------------------
182
+ # Backend array
183
+ # ---------------------------------------------------------------------------
184
+
185
+
186
class TensogramBackendArray(BackendArray):
    """Lazy array backed by one data object of a tensogram file.

    Stores the file path (or remote URL) and optionally a shared file handle.
    The handle is dropped on pickle for dask multiprocessing compatibility;
    without a shared handle, a new one is opened for each read.

    Parameters
    ----------
    file_path
        Local path (stored as an absolute path) or remote URL.
    msg_index, obj_index
        Message index in the file and object index within that message.
    shape, dtype
        Shape and dtype of the decoded tensor.
    supports_range
        Whether the object's pipeline allows partial ``decode_range()`` reads.
    verify_hash
        Forwarded to the tensogram decode calls.
    range_threshold
        Maximum requested/total element ratio for which a partial read is
        attempted.
    lock
        Accepted for interface compatibility; not used by this class.
    storage_options
        Forwarded to ``TensogramFile.open_remote`` for remote URLs.
    shared_file
        Optional already-open file handle reused for every read.
    """

    def __init__(
        self,
        file_path: str,
        msg_index: int,
        obj_index: int,
        shape: tuple[int, ...],
        dtype: np.dtype,
        supports_range: bool,
        *,
        verify_hash: bool = False,
        range_threshold: float = DEFAULT_RANGE_THRESHOLD,
        lock: threading.Lock | None = None,
        storage_options: dict[str, Any] | None = None,
        shared_file: Any | None = None,
    ):
        import tensogram

        self._is_remote = tensogram.is_remote_url(file_path)
        # Local paths are made absolute; remote URLs are kept verbatim.
        self.file_path = file_path if self._is_remote else os.path.abspath(file_path)
        self.msg_index = msg_index
        self.obj_index = obj_index
        self.shape = shape
        self.dtype = dtype
        self.supports_range = supports_range
        self.verify_hash = verify_hash
        self.range_threshold = range_threshold
        self.storage_options = storage_options
        self._shared_file = shared_file

    # -- pickle support (no open handles stored) ----------------------------

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state["_shared_file"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self._shared_file = None

    # -- BackendArray interface ---------------------------------------------

    def __getitem__(self, key: indexing.ExplicitIndexer) -> np.ndarray:
        return indexing.explicit_indexing_adapter(
            key,
            self.shape,
            indexing.IndexingSupport.BASIC,
            self._raw_indexing_method,
        )

    def _get_file(self):
        """Return the shared handle, or open a new local/remote handle."""
        if self._shared_file is not None:
            return self._shared_file
        import tensogram

        if self._is_remote:
            return tensogram.TensogramFile.open_remote(self.file_path, self.storage_options or {})
        return tensogram.TensogramFile.open(self.file_path)

    def _raw_indexing_method(self, key: tuple) -> np.ndarray:
        import tensogram

        if self._shared_file is not None:
            return self._read_from_file(self._shared_file, key, tensogram)

        # A handle opened here is owned by this call and closed afterwards.
        with self._get_file() as f:
            return self._read_from_file(f, key, tensogram)

    def _read_from_file(self, f, key: tuple, tensogram) -> np.ndarray:
        """Return ``self[key]`` read through the open file handle *f*.

        A partial ``file_decode_range()`` read is tried when the pipeline
        supports it, *key* is all unit-stride slices, and the requested
        fraction of elements is within ``range_threshold``.  Otherwise (or if
        the partial read fails) the whole object is decoded and sliced.
        """
        # Address every axis explicitly; numpy treats omitted trailing axes
        # as ``:``, so slicing results are unchanged.
        key = tuple(key) + (slice(None),) * (len(self.shape) - len(key))

        if self.supports_range and _is_contiguous_slice(key):
            sel_shape = tuple(len(range(*k.indices(n))) for k, n in zip(key, self.shape))
            if 0 in sel_shape:
                # Empty selection: nothing to decode.
                return np.empty(sel_shape, dtype=self.dtype)
            try:
                flat_ranges, out_shape = _nd_slice_to_flat_ranges(self.shape, key)
                total_requested = sum(c for _, c in flat_ranges)
                total_elements = math.prod(self.shape)

                if total_elements > 0 and total_requested / total_elements <= self.range_threshold:
                    arr = f.file_decode_range(
                        self.msg_index,
                        obj_index=self.obj_index,
                        ranges=flat_ranges,
                        join=True,
                        verify_hash=self.verify_hash,
                        native_byte_order=True,
                    )
                    return np.asarray(arr).reshape(out_shape)
            except (ValueError, RuntimeError, OSError) as exc:
                logger.debug(
                    "decode_range failed for %s msg=%d obj=%d, falling back to full decode: %s",
                    self.file_path,
                    self.msg_index,
                    self.obj_index,
                    exc,
                )

        if self._is_remote:
            result = f.file_decode_object(
                self.msg_index,
                self.obj_index,
                verify_hash=self.verify_hash,
            )
            return np.asarray(result["data"][key])

        raw_msg = f.read_message(self.msg_index)
        _meta, _desc, arr = tensogram.decode_object(
            raw_msg,
            self.obj_index,
            verify_hash=self.verify_hash,
        )
        return np.asarray(arr[key])
303
+
304
+
305
+ # ---------------------------------------------------------------------------
306
+ # Stacked backend array (lazy hypercube)
307
+ # ---------------------------------------------------------------------------
308
+
309
+
310
class StackedBackendArray(BackendArray):
    """Lazy stacked array composed of multiple :class:`TensogramBackendArray`.

    The full shape is ``outer_shape + inner_shape``.  Each position on the
    outer grid maps, in row-major (C) order, to one entry of *arrays*; each
    backing array is expected to have shape ``inner_shape``.  Indexing
    dispatches to the backing arrays it needs and assembles the result, so no
    data is decoded until actually accessed.

    Raises
    ------
    ValueError
        If ``len(arrays) != prod(outer_shape)``.
    """

    def __init__(
        self,
        arrays: list[TensogramBackendArray],
        outer_shape: tuple[int, ...],
        inner_shape: tuple[int, ...],
        dtype: np.dtype,
    ):
        if len(arrays) != math.prod(outer_shape):
            msg = (
                f"StackedBackendArray: expected {math.prod(outer_shape)} "
                f"backing arrays for outer_shape={outer_shape}, "
                f"got {len(arrays)}"
            )
            raise ValueError(msg)

        self._arrays = arrays
        self._outer_shape = outer_shape
        self._inner_shape = inner_shape
        self.shape = outer_shape + inner_shape
        self.dtype = dtype

    def __getitem__(self, key: indexing.ExplicitIndexer) -> np.ndarray:
        return indexing.explicit_indexing_adapter(
            key,
            self.shape,
            indexing.IndexingSupport.BASIC,
            self._raw_indexing_method,
        )

    @staticmethod
    def _normalize_index(k: Any, size: int) -> Any:
        """Turn an integer key (Python or numpy, possibly negative) into a
        non-negative Python ``int``; return any other key unchanged.

        Raises IndexError if the integer is out of bounds for *size*.
        """
        if not isinstance(k, (int, np.integer)):
            return k
        idx = int(k)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            msg = f"index {int(k)} is out of bounds for axis with size {size}"
            raise IndexError(msg)
        return idx

    def _raw_indexing_method(self, key: tuple) -> np.ndarray:
        """Return the basic-indexing selection ``self[key]`` as a numpy array.

        *key* holds slices and/or integers.  Integer keys select a single
        position and drop that axis while slices keep it -- numpy's
        basic-indexing semantics, applied to outer and inner axes alike.
        """
        n_outer = len(self._outer_shape)
        # Treat omitted trailing axes as full slices.
        key = tuple(key) + (slice(None),) * (len(self.shape) - len(key))

        outer_key = tuple(
            self._normalize_index(k, size) for k, size in zip(key[:n_outer], self._outer_shape)
        )
        inner_key = key[n_outer:]

        # Concrete outer indices selected along each outer axis.
        outer_indices = _expand_key_to_indices(outer_key, self._outer_shape)

        # Outer axes indexed by an integer are dropped from the output.
        kept_axes = [d for d, k in enumerate(outer_key) if not isinstance(k, int)]
        outer_out_shape = tuple(len(outer_indices[d]) for d in kept_axes)

        # Inner axes: slices keep the axis (with the slice length), integer
        # keys drop it -- the backing arrays' numpy indexing does the same.
        inner_out_shape = tuple(
            len(range(*k.indices(s)))
            for k, s in zip(inner_key, self._inner_shape)
            if isinstance(k, slice)
        )

        result = np.empty(outer_out_shape + inner_out_shape, dtype=self.dtype)

        # Row-major strides of the outer grid, used to locate backing arrays.
        outer_strides = [1] * n_outer
        for d in range(n_outer - 2, -1, -1):
            outer_strides[d] = outer_strides[d + 1] * self._outer_shape[d + 1]

        # Iterate over positions *within* each axis' selected index list so the
        # output position is known directly, without unravelling.
        for positions in iterproduct(*(range(len(idx)) for idx in outer_indices)):
            flat_idx = sum(
                outer_indices[d][p] * outer_strides[d] for d, p in enumerate(positions)
            )
            out_idx = tuple(positions[d] for d in kept_axes)
            result[out_idx] = self._arrays[flat_idx]._raw_indexing_method(inner_key)

        return result
396
+
397
+
398
+ def _expand_key_to_indices(key: tuple, shape: tuple[int, ...]) -> list[list[int]]:
399
+ """Expand a tuple of slices/ints into lists of concrete indices."""
400
+ result: list[list[int]] = []
401
+ for k, size in zip(key, shape):
402
+ if isinstance(k, slice):
403
+ result.append(list(range(*k.indices(size))))
404
+ elif isinstance(k, int):
405
+ result.append([k])
406
+ else:
407
+ result.append(list(range(size)))
408
+ return result
@@ -0,0 +1,139 @@
1
+ # (C) Copyright 2026- ECMWF and individual contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation nor
7
+ # does it submit to any jurisdiction.
8
+
9
+ """xarray backend entry point for tensogram ``.tgm`` files.
10
+
11
+ Registers ``engine="tensogram"`` with xarray via the ``xarray.backends``
12
+ entry point in ``pyproject.toml``.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ from collections.abc import Iterable, Sequence
19
+ from typing import Any
20
+
21
+ import xarray as xr
22
+ from xarray.backends import BackendEntrypoint
23
+
24
+ from tensogram_xarray.store import TensogramDataStore
25
+
26
+
27
class TensogramBackendEntrypoint(BackendEntrypoint):
    """Open tensogram ``.tgm`` files as xarray Datasets.

    Usage::

        import xarray as xr

        # Simple open (single message, generic dim names)
        ds = xr.open_dataset("file.tgm", engine="tensogram")

        # With user-specified dimension names
        ds = xr.open_dataset("file.tgm", engine="tensogram",
                             dim_names=["latitude", "longitude"])

        # With variable naming from metadata
        ds = xr.open_dataset("file.tgm", engine="tensogram",
                             variable_key="mars.param")
    """

    description = "Open tensogram .tgm files in xarray"
    url = "https://github.com/ecmwf/tensogram"

    def open_dataset(  # type: ignore[override]
        self,
        filename_or_obj: str | os.PathLike,
        *,
        drop_variables: str | Iterable[str] | None = None,
        dim_names: Sequence[str] | None = None,
        variable_key: str | None = None,
        message_index: int = 0,
        merge_objects: bool = False,
        verify_hash: bool = False,
        range_threshold: float = 0.5,
        storage_options: dict[str, Any] | None = None,
    ) -> xr.Dataset:
        """Open a single tensogram message as an :class:`xr.Dataset`.

        Parameters
        ----------
        filename_or_obj
            Path to a ``.tgm`` file (or a remote URL, see *storage_options*).
        drop_variables
            A variable name, or names, to exclude from the Dataset.  Applied
            in both the single-message and the *merge_objects* mode.
        dim_names
            Explicit dimension names for data variables. Must have exactly
            as many entries as the tensor has axes.
        variable_key
            Dotted metadata path (e.g. ``"mars.param"``) whose value at each
            data object becomes the xarray variable name.
        message_index
            Which message to open when the file contains multiple messages.
            Ignored when *merge_objects* is *True*.
        merge_objects
            If *True*, attempt to merge objects across messages by stacking
            along metadata dimensions that vary, and return the first
            resulting Dataset (an empty Dataset if there is none).  When
            *False* (default), only the single message at *message_index*
            is opened.
        verify_hash
            Whether to verify xxh3 hashes during decode.
        range_threshold
            Maximum fraction of total array elements (0.0-1.0) for which
            partial ``decode_range()`` is used instead of a full
            ``decode_object()``. Default ``0.5`` (50%).
        storage_options
            Key-value pairs forwarded to the object store backend when
            the path is a remote URL. Ignored for local files.

        Returns
        -------
        xr.Dataset

        Raises
        ------
        ValueError
            If *message_index* is negative.
        """
        file_path = str(filename_or_obj)

        if message_index < 0:
            msg = f"message_index must be >= 0, got {message_index}"
            raise ValueError(msg)

        # xarray allows a single variable name here; wrap it so ``set()``
        # does not split the string into individual characters.
        if isinstance(drop_variables, str):
            drop_variables = [drop_variables]
        drop_set = set(drop_variables) if drop_variables else None

        if merge_objects:
            # Delegate to open_datasets and return the first result.
            from tensogram_xarray.merge import open_datasets

            datasets = open_datasets(
                file_path,
                dim_names=dim_names,
                variable_key=variable_key,
                verify_hash=verify_hash,
                range_threshold=range_threshold,
                storage_options=storage_options,
            )
            if not datasets:
                return xr.Dataset()
            ds = datasets[0]
            if drop_set:
                trimmed = ds.drop_vars(drop_set, errors="ignore")
                # Forward close() to the source Dataset so any resources it
                # holds are still released when the caller closes the result.
                trimmed.set_close(ds.close)
                return trimmed
            return ds

        store = TensogramDataStore(
            file_path=file_path,
            msg_index=message_index,
            dim_names=dim_names,
            variable_key=variable_key,
            verify_hash=verify_hash,
            range_threshold=range_threshold,
            storage_options=storage_options,
        )

        ds = store.build_dataset(drop_variables=drop_set)
        ds.set_close(store.close)
        return ds

    def guess_can_open(self, filename_or_obj: str) -> bool:  # type: ignore[override]
        """Return *True* for files with ``.tgm`` extension."""
        try:
            _, ext = os.path.splitext(filename_or_obj)
        except (TypeError, AttributeError):
            return False
        return ext.lower() == ".tgm"
@@ -0,0 +1,113 @@
1
+ # (C) Copyright 2026- ECMWF and individual contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation nor
7
+ # does it submit to any jurisdiction.
8
+
9
+ """Coordinate detection by name matching.
10
+
11
+ When a data object's per-object metadata contains a ``name`` or ``param`` key
12
+ whose value matches a known coordinate name (case-insensitive), that object is
13
+ treated as a coordinate array rather than a data variable.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections.abc import Sequence
19
+ from typing import Any
20
+
21
# Canonical name mapping: recognised coordinate names (all lower-case, for
# case-insensitive matching) -> preferred dimension name.
CANONICAL_DIM: dict[str, str] = {
    "lat": "latitude",
    "latitude": "latitude",
    "lon": "longitude",
    "longitude": "longitude",
    "x": "x",
    "y": "y",
    "time": "time",
    "level": "level",
    "pressure": "pressure",
    "height": "height",
    "depth": "depth",
    "frequency": "frequency",
    "step": "step",
}

# Known coordinate names.  Derived from CANONICAL_DIM so the two tables can
# never drift out of sync (a module-level ``assert`` would be stripped when
# Python runs with ``-O``).
KNOWN_COORD_NAMES: frozenset[str] = frozenset(CANONICAL_DIM)
65
+
66
+
67
+ def _get_object_name(meta: dict[str, Any]) -> str | None:
68
+ """Extract the name/param identifier from per-object metadata.
69
+
70
+ Checks ``name``, ``param``, and nested ``mars.param`` in that order.
71
+ """
72
+ if "name" in meta:
73
+ return str(meta["name"])
74
+ if "param" in meta:
75
+ return str(meta["param"])
76
+ mars = meta.get("mars")
77
+ if isinstance(mars, dict) and "param" in mars:
78
+ return str(mars["param"])
79
+ return None
80
+
81
+
82
def detect_coords(
    object_metas: Sequence[dict[str, Any]],
) -> tuple[list[int], list[int], dict[int, str]]:
    """Split data objects into coordinate objects and data-variable objects.

    An object is a coordinate when its identifier (``name``, ``param`` or
    ``mars.param``), lower-cased, is one of the known coordinate names.

    Parameters
    ----------
    object_metas
        Per-object metadata dicts (one per data object in the message).

    Returns
    -------
    coord_indices
        Positions of objects identified as coordinates.
    var_indices
        Positions of all remaining objects (data variables).
    coord_dim_names
        For each coordinate position, its canonical dimension name.
    """
    coord_indices: list[int] = []
    var_indices: list[int] = []
    coord_dim_names: dict[int, str] = {}

    for position, meta in enumerate(object_metas):
        label = _get_object_name(meta)
        lowered = None if label is None else label.lower()
        if lowered is None or lowered not in KNOWN_COORD_NAMES:
            var_indices.append(position)
            continue
        coord_indices.append(position)
        coord_dim_names[position] = CANONICAL_DIM[lowered]

    return coord_indices, var_indices, coord_dim_names